lib.rs

  1#[cfg(any(test, feature = "test-support"))]
  2pub mod test;
  3
  4use anyhow::{anyhow, Context, Result};
  5use async_recursion::async_recursion;
  6use async_tungstenite::tungstenite::{
  7    error::Error as WebsocketError,
  8    http::{Request, StatusCode},
  9};
 10use gpui::{AsyncAppContext, Entity, ModelContext, Task};
 11use lazy_static::lazy_static;
 12use parking_lot::RwLock;
 13use postage::{prelude::Stream, watch};
 14use rand::prelude::*;
 15use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
 16use std::{
 17    any::TypeId,
 18    collections::HashMap,
 19    convert::TryFrom,
 20    fmt::Write as _,
 21    future::Future,
 22    sync::{Arc, Weak},
 23    time::{Duration, Instant},
 24};
 25use surf::Url;
 26use thiserror::Error;
 27use util::ResultExt;
 28
 29pub use rpc::*;
 30
 31lazy_static! {
 32    static ref ZED_SERVER_URL: String =
 33        std::env::var("ZED_SERVER_URL").unwrap_or("https://zed.dev:443".to_string());
 34    static ref IMPERSONATE_LOGIN: Option<String> = std::env::var("ZED_IMPERSONATE")
 35        .ok()
 36        .and_then(|s| if s.is_empty() { None } else { Some(s) });
 37}
 38
 39pub struct Client {
 40    peer: Arc<Peer>,
 41    state: RwLock<ClientState>,
 42    authenticate:
 43        Option<Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>>>,
 44    establish_connection: Option<
 45        Box<
 46            dyn 'static
 47                + Send
 48                + Sync
 49                + Fn(
 50                    &Credentials,
 51                    &AsyncAppContext,
 52                ) -> Task<Result<Connection, EstablishConnectionError>>,
 53        >,
 54    >,
 55}
 56
 57#[derive(Error, Debug)]
 58pub enum EstablishConnectionError {
 59    #[error("upgrade required")]
 60    UpgradeRequired,
 61    #[error("unauthorized")]
 62    Unauthorized,
 63    #[error("{0}")]
 64    Other(#[from] anyhow::Error),
 65    #[error("{0}")]
 66    Io(#[from] std::io::Error),
 67    #[error("{0}")]
 68    Http(#[from] async_tungstenite::tungstenite::http::Error),
 69}
 70
 71impl From<WebsocketError> for EstablishConnectionError {
 72    fn from(error: WebsocketError) -> Self {
 73        if let WebsocketError::Http(response) = &error {
 74            match response.status() {
 75                StatusCode::UNAUTHORIZED => return EstablishConnectionError::Unauthorized,
 76                StatusCode::UPGRADE_REQUIRED => return EstablishConnectionError::UpgradeRequired,
 77                _ => {}
 78            }
 79        }
 80        EstablishConnectionError::Other(error.into())
 81    }
 82}
 83
 84impl EstablishConnectionError {
 85    pub fn other(error: impl Into<anyhow::Error> + Send + Sync) -> Self {
 86        Self::Other(error.into())
 87    }
 88}
 89
 90#[derive(Copy, Clone, Debug)]
 91pub enum Status {
 92    SignedOut,
 93    UpgradeRequired,
 94    Authenticating,
 95    Connecting,
 96    ConnectionError,
 97    Connected { connection_id: ConnectionId },
 98    ConnectionLost,
 99    Reauthenticating,
100    Reconnecting,
101    ReconnectionError { next_reconnection: Instant },
102}
103
104struct ClientState {
105    credentials: Option<Credentials>,
106    status: (watch::Sender<Status>, watch::Receiver<Status>),
107    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
108    model_handlers: HashMap<
109        (TypeId, u64),
110        Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
111    >,
112    _maintain_connection: Option<Task<()>>,
113    heartbeat_interval: Duration,
114}
115
/// User id and access token used to authenticate the server connection.
#[derive(Clone)]
pub struct Credentials {
    /// Id of the authenticated user.
    pub user_id: u64,
    /// Secret token sent together with `user_id` in the `Authorization` header.
    pub access_token: String,
}
121
122impl Default for ClientState {
123    fn default() -> Self {
124        Self {
125            credentials: None,
126            status: watch::channel_with(Status::SignedOut),
127            entity_id_extractors: Default::default(),
128            model_handlers: Default::default(),
129            _maintain_connection: None,
130            heartbeat_interval: Duration::from_secs(5),
131        }
132    }
133}
134
135pub struct Subscription {
136    client: Weak<Client>,
137    id: (TypeId, u64),
138}
139
140impl Drop for Subscription {
141    fn drop(&mut self) {
142        if let Some(client) = self.client.upgrade() {
143            drop(
144                client
145                    .state
146                    .write()
147                    .model_handlers
148                    .remove(&self.id)
149                    .unwrap(),
150            );
151        }
152    }
153}
154
155impl Client {
156    pub fn new() -> Arc<Self> {
157        Arc::new(Self {
158            peer: Peer::new(),
159            state: Default::default(),
160            authenticate: None,
161            establish_connection: None,
162        })
163    }
164
165    #[cfg(any(test, feature = "test-support"))]
166    pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
167    where
168        F: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<Credentials>>,
169    {
170        self.authenticate = Some(Box::new(authenticate));
171        self
172    }
173
174    #[cfg(any(test, feature = "test-support"))]
175    pub fn override_establish_connection<F>(&mut self, connect: F) -> &mut Self
176    where
177        F: 'static
178            + Send
179            + Sync
180            + Fn(&Credentials, &AsyncAppContext) -> Task<Result<Connection, EstablishConnectionError>>,
181    {
182        self.establish_connection = Some(Box::new(connect));
183        self
184    }
185
186    pub fn user_id(&self) -> Option<u64> {
187        self.state
188            .read()
189            .credentials
190            .as_ref()
191            .map(|credentials| credentials.user_id)
192    }
193
194    pub fn status(&self) -> watch::Receiver<Status> {
195        self.state.read().status.1.clone()
196    }
197
198    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
199        let mut state = self.state.write();
200        *state.status.0.borrow_mut() = status;
201
202        match status {
203            Status::Connected { .. } => {
204                let heartbeat_interval = state.heartbeat_interval;
205                let this = self.clone();
206                let foreground = cx.foreground();
207                state._maintain_connection = Some(cx.foreground().spawn(async move {
208                    loop {
209                        foreground.timer(heartbeat_interval).await;
210                        let _ = this.request(proto::Ping {}).await;
211                    }
212                }));
213            }
214            Status::ConnectionLost => {
215                let this = self.clone();
216                let foreground = cx.foreground();
217                let heartbeat_interval = state.heartbeat_interval;
218                state._maintain_connection = Some(cx.spawn(|cx| async move {
219                    let mut rng = StdRng::from_entropy();
220                    let mut delay = Duration::from_millis(100);
221                    while let Err(error) = this.authenticate_and_connect(&cx).await {
222                        log::error!("failed to connect {}", error);
223                        this.set_status(
224                            Status::ReconnectionError {
225                                next_reconnection: Instant::now() + delay,
226                            },
227                            &cx,
228                        );
229                        foreground.timer(delay).await;
230                        delay = delay
231                            .mul_f32(rng.gen_range(1.0..=2.0))
232                            .min(heartbeat_interval);
233                    }
234                }));
235            }
236            Status::SignedOut | Status::UpgradeRequired => {
237                state._maintain_connection.take();
238            }
239            _ => {}
240        }
241    }
242
243    pub fn subscribe<T, M, F>(
244        self: &Arc<Self>,
245        cx: &mut ModelContext<M>,
246        mut handler: F,
247    ) -> Subscription
248    where
249        T: EnvelopedMessage,
250        M: Entity,
251        F: 'static
252            + Send
253            + Sync
254            + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
255    {
256        let subscription_id = (TypeId::of::<T>(), Default::default());
257        let client = self.clone();
258        let mut state = self.state.write();
259        let model = cx.handle().downgrade();
260        let prev_extractor = state
261            .entity_id_extractors
262            .insert(subscription_id.0, Box::new(|_| Default::default()));
263        if prev_extractor.is_some() {
264            panic!("registered a handler for the same entity twice")
265        }
266
267        state.model_handlers.insert(
268            subscription_id,
269            Box::new(move |envelope, cx| {
270                if let Some(model) = model.upgrade(cx) {
271                    let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
272                    model.update(cx, |model, cx| {
273                        if let Err(error) = handler(model, *envelope, client.clone(), cx) {
274                            log::error!("error handling message: {}", error)
275                        }
276                    });
277                }
278            }),
279        );
280
281        Subscription {
282            client: Arc::downgrade(self),
283            id: subscription_id,
284        }
285    }
286
287    pub fn subscribe_to_entity<T, M, F>(
288        self: &Arc<Self>,
289        remote_id: u64,
290        cx: &mut ModelContext<M>,
291        mut handler: F,
292    ) -> Subscription
293    where
294        T: EntityMessage,
295        M: Entity,
296        F: 'static
297            + Send
298            + Sync
299            + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
300    {
301        let subscription_id = (TypeId::of::<T>(), remote_id);
302        let client = self.clone();
303        let mut state = self.state.write();
304        let model = cx.handle().downgrade();
305        state
306            .entity_id_extractors
307            .entry(subscription_id.0)
308            .or_insert_with(|| {
309                Box::new(|envelope| {
310                    let envelope = envelope
311                        .as_any()
312                        .downcast_ref::<TypedEnvelope<T>>()
313                        .unwrap();
314                    envelope.payload.remote_entity_id()
315                })
316            });
317        let prev_handler = state.model_handlers.insert(
318            subscription_id,
319            Box::new(move |envelope, cx| {
320                if let Some(model) = model.upgrade(cx) {
321                    let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
322                    model.update(cx, |model, cx| {
323                        if let Err(error) = handler(model, *envelope, client.clone(), cx) {
324                            log::error!("error handling message: {}", error)
325                        }
326                    });
327                }
328            }),
329        );
330        if prev_handler.is_some() {
331            panic!("registered a handler for the same entity twice")
332        }
333
334        Subscription {
335            client: Arc::downgrade(self),
336            id: subscription_id,
337        }
338    }
339
340    #[async_recursion(?Send)]
341    pub async fn authenticate_and_connect(
342        self: &Arc<Self>,
343        cx: &AsyncAppContext,
344    ) -> anyhow::Result<()> {
345        let was_disconnected = match *self.status().borrow() {
346            Status::SignedOut => true,
347            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
348                false
349            }
350            Status::Connected { .. }
351            | Status::Connecting { .. }
352            | Status::Reconnecting { .. }
353            | Status::Authenticating
354            | Status::Reauthenticating => return Ok(()),
355            Status::UpgradeRequired => return Err(EstablishConnectionError::UpgradeRequired)?,
356        };
357
358        if was_disconnected {
359            self.set_status(Status::Authenticating, cx);
360        } else {
361            self.set_status(Status::Reauthenticating, cx)
362        }
363
364        let mut used_keychain = false;
365        let credentials = self.state.read().credentials.clone();
366        let credentials = if let Some(credentials) = credentials {
367            credentials
368        } else if let Some(credentials) = read_credentials_from_keychain(cx) {
369            used_keychain = true;
370            credentials
371        } else {
372            let credentials = match self.authenticate(&cx).await {
373                Ok(credentials) => credentials,
374                Err(err) => {
375                    self.set_status(Status::ConnectionError, cx);
376                    return Err(err);
377                }
378            };
379            credentials
380        };
381
382        if was_disconnected {
383            self.set_status(Status::Connecting, cx);
384        } else {
385            self.set_status(Status::Reconnecting, cx);
386        }
387
388        match self.establish_connection(&credentials, cx).await {
389            Ok(conn) => {
390                log::info!("connected to rpc address {}", *ZED_SERVER_URL);
391                self.state.write().credentials = Some(credentials.clone());
392                if !used_keychain && IMPERSONATE_LOGIN.is_none() {
393                    write_credentials_to_keychain(&credentials, cx).log_err();
394                }
395                self.set_connection(conn, cx).await;
396                Ok(())
397            }
398            Err(EstablishConnectionError::Unauthorized) => {
399                self.state.write().credentials.take();
400                if used_keychain {
401                    cx.platform().delete_credentials(&ZED_SERVER_URL).log_err();
402                    self.set_status(Status::SignedOut, cx);
403                    self.authenticate_and_connect(cx).await
404                } else {
405                    self.set_status(Status::ConnectionError, cx);
406                    Err(EstablishConnectionError::Unauthorized)?
407                }
408            }
409            Err(EstablishConnectionError::UpgradeRequired) => {
410                self.set_status(Status::UpgradeRequired, cx);
411                Err(EstablishConnectionError::UpgradeRequired)?
412            }
413            Err(error) => {
414                self.set_status(Status::ConnectionError, cx);
415                Err(error)?
416            }
417        }
418    }
419
420    async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
421        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
422        cx.foreground()
423            .spawn({
424                let mut cx = cx.clone();
425                let this = self.clone();
426                async move {
427                    while let Some(message) = incoming.recv().await {
428                        let mut state = this.state.write();
429                        if let Some(extract_entity_id) =
430                            state.entity_id_extractors.get(&message.payload_type_id())
431                        {
432                            let payload_type_id = message.payload_type_id();
433                            let entity_id = (extract_entity_id)(message.as_ref());
434                            let handler_key = (payload_type_id, entity_id);
435                            if let Some(mut handler) = state.model_handlers.remove(&handler_key) {
436                                drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
437                                let start_time = Instant::now();
438                                log::info!("RPC client message {}", message.payload_type_name());
439                                (handler)(message, &mut cx);
440                                log::info!(
441                                    "RPC message handled. duration:{:?}",
442                                    start_time.elapsed()
443                                );
444                                this.state
445                                    .write()
446                                    .model_handlers
447                                    .insert(handler_key, handler);
448                            } else {
449                                log::info!("unhandled message {}", message.payload_type_name());
450                            }
451                        } else {
452                            log::info!("unhandled message {}", message.payload_type_name());
453                        }
454                    }
455                }
456            })
457            .detach();
458
459        self.set_status(Status::Connected { connection_id }, cx);
460
461        let handle_io = cx.background().spawn(handle_io);
462        let this = self.clone();
463        let cx = cx.clone();
464        cx.foreground()
465            .spawn(async move {
466                match handle_io.await {
467                    Ok(()) => this.set_status(Status::SignedOut, &cx),
468                    Err(err) => {
469                        log::error!("connection error: {:?}", err);
470                        this.set_status(Status::ConnectionLost, &cx);
471                    }
472                }
473            })
474            .detach();
475    }
476
477    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<Credentials>> {
478        if let Some(callback) = self.authenticate.as_ref() {
479            callback(cx)
480        } else {
481            self.authenticate_with_browser(cx)
482        }
483    }
484
485    fn establish_connection(
486        self: &Arc<Self>,
487        credentials: &Credentials,
488        cx: &AsyncAppContext,
489    ) -> Task<Result<Connection, EstablishConnectionError>> {
490        if let Some(callback) = self.establish_connection.as_ref() {
491            callback(credentials, cx)
492        } else {
493            self.establish_websocket_connection(credentials, cx)
494        }
495    }
496
497    fn establish_websocket_connection(
498        self: &Arc<Self>,
499        credentials: &Credentials,
500        cx: &AsyncAppContext,
501    ) -> Task<Result<Connection, EstablishConnectionError>> {
502        let request = Request::builder()
503            .header(
504                "Authorization",
505                format!("{} {}", credentials.user_id, credentials.access_token),
506            )
507            .header("X-Zed-Protocol-Version", rpc::PROTOCOL_VERSION);
508        cx.background().spawn(async move {
509            if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
510                let stream = smol::net::TcpStream::connect(host).await?;
511                let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
512                let (stream, _) =
513                    async_tungstenite::async_tls::client_async_tls(request, stream).await?;
514                Ok(Connection::new(stream))
515            } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
516                let stream = smol::net::TcpStream::connect(host).await?;
517                let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
518                let (stream, _) = async_tungstenite::client_async(request, stream).await?;
519                Ok(Connection::new(stream))
520            } else {
521                Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?
522            }
523        })
524    }
525
526    pub fn authenticate_with_browser(
527        self: &Arc<Self>,
528        cx: &AsyncAppContext,
529    ) -> Task<Result<Credentials>> {
530        let platform = cx.platform();
531        let executor = cx.background();
532        executor.clone().spawn(async move {
533            // Generate a pair of asymmetric encryption keys. The public key will be used by the
534            // zed server to encrypt the user's access token, so that it can'be intercepted by
535            // any other app running on the user's device.
536            let (public_key, private_key) =
537                rpc::auth::keypair().expect("failed to generate keypair for auth");
538            let public_key_string =
539                String::try_from(public_key).expect("failed to serialize public key for auth");
540
541            // Start an HTTP server to receive the redirect from Zed's sign-in page.
542            let server = tiny_http::Server::http("127.0.0.1:0").expect("failed to find open port");
543            let port = server.server_addr().port();
544
545            // Open the Zed sign-in page in the user's browser, with query parameters that indicate
546            // that the user is signing in from a Zed app running on the same device.
547            let mut url = format!(
548                "{}/sign_in?native_app_port={}&native_app_public_key={}",
549                *ZED_SERVER_URL, port, public_key_string
550            );
551
552            if let Some(impersonate_login) = IMPERSONATE_LOGIN.as_ref() {
553                log::info!("impersonating user @{}", impersonate_login);
554                write!(&mut url, "&impersonate={}", impersonate_login).unwrap();
555            }
556
557            platform.open_url(&url);
558
559            // Receive the HTTP request from the user's browser. Retrieve the user id and encrypted
560            // access token from the query params.
561            //
562            // TODO - Avoid ever starting more than one HTTP server. Maybe switch to using a
563            // custom URL scheme instead of this local HTTP server.
564            let (user_id, access_token) = executor
565                .spawn(async move {
566                    if let Some(req) = server.recv_timeout(Duration::from_secs(10 * 60))? {
567                        let path = req.url();
568                        let mut user_id = None;
569                        let mut access_token = None;
570                        let url = Url::parse(&format!("http://example.com{}", path))
571                            .context("failed to parse login notification url")?;
572                        for (key, value) in url.query_pairs() {
573                            if key == "access_token" {
574                                access_token = Some(value.to_string());
575                            } else if key == "user_id" {
576                                user_id = Some(value.to_string());
577                            }
578                        }
579                        req.respond(
580                            tiny_http::Response::from_string(LOGIN_RESPONSE).with_header(
581                                tiny_http::Header::from_bytes("Content-Type", "text/html").unwrap(),
582                            ),
583                        )
584                        .context("failed to respond to login http request")?;
585                        Ok((
586                            user_id.ok_or_else(|| anyhow!("missing user_id parameter"))?,
587                            access_token
588                                .ok_or_else(|| anyhow!("missing access_token parameter"))?,
589                        ))
590                    } else {
591                        Err(anyhow!("didn't receive login redirect"))
592                    }
593                })
594                .await?;
595
596            let access_token = private_key
597                .decrypt_string(&access_token)
598                .context("failed to decrypt access token")?;
599            platform.activate(true);
600
601            Ok(Credentials {
602                user_id: user_id.parse()?,
603                access_token,
604            })
605        })
606    }
607
608    pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
609        let conn_id = self.connection_id()?;
610        self.peer.disconnect(conn_id).await;
611        self.set_status(Status::SignedOut, cx);
612        Ok(())
613    }
614
615    fn connection_id(&self) -> Result<ConnectionId> {
616        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
617            Ok(connection_id)
618        } else {
619            Err(anyhow!("not connected"))
620        }
621    }
622
623    pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
624        self.peer.send(self.connection_id()?, message).await
625    }
626
627    pub async fn request<T: RequestMessage>(&self, request: T) -> Result<T::Response> {
628        self.peer.request(self.connection_id()?, request).await
629    }
630
631    pub fn respond<T: RequestMessage>(
632        &self,
633        receipt: Receipt<T>,
634        response: T::Response,
635    ) -> impl Future<Output = Result<()>> {
636        self.peer.respond(receipt, response)
637    }
638}
639
640fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option<Credentials> {
641    if IMPERSONATE_LOGIN.is_some() {
642        return None;
643    }
644
645    let (user_id, access_token) = cx
646        .platform()
647        .read_credentials(&ZED_SERVER_URL)
648        .log_err()
649        .flatten()?;
650    Some(Credentials {
651        user_id: user_id.parse().ok()?,
652        access_token: String::from_utf8(access_token).ok()?,
653    })
654}
655
656fn write_credentials_to_keychain(credentials: &Credentials, cx: &AsyncAppContext) -> Result<()> {
657    cx.platform().write_credentials(
658        &ZED_SERVER_URL,
659        &credentials.user_id.to_string(),
660        credentials.access_token.as_bytes(),
661    )
662}
663
/// Scheme and path prefix shared by all worktree URLs.
const WORKTREE_URL_PREFIX: &str = "zed://worktrees/";

/// Builds a worktree URL of the form `zed://worktrees/{id}/{access_token}`.
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
    let mut url = String::from(WORKTREE_URL_PREFIX);
    url.push_str(&id.to_string());
    url.push('/');
    url.push_str(access_token);
    url
}

/// Parses a URL produced by [`encode_worktree_url`], ignoring surrounding whitespace.
///
/// Returns `None` if the prefix is missing, the id is not a `u64`, or the access token is
/// absent or empty. Any path segments after the token are ignored.
pub fn decode_worktree_url(url: &str) -> Option<(u64, String)> {
    let rest = url.trim().strip_prefix(WORKTREE_URL_PREFIX)?;
    let mut segments = rest.split('/');
    let id: u64 = segments.next()?.parse().ok()?;
    match segments.next()? {
        "" => None,
        token => Some((id, token.to_owned())),
    }
}
680
/// HTML returned to the browser after the sign-in redirect; its script calls
/// `window.close()`.
const LOGIN_RESPONSE: &str = "
<!DOCTYPE html>
<html>
<script>window.close();</script>
</html>
";
687
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::FakeServer;
    use gpui::TestAppContext;

    /// A connected client pings the server on every heartbeat tick and stops after
    /// `disconnect`.
    #[gpui::test(iterations = 10)]
    async fn test_heartbeat(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new();
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;

        // Advancing the fake clock past the default heartbeat interval (5s) triggers a ping.
        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        cx.foreground().advance_clock(Duration::from_secs(10));
        let ping = server.receive::<proto::Ping>().await.unwrap();
        server.respond(ping.receipt(), proto::Ack {}).await;

        // After disconnecting, the heartbeat task is dropped, so no further ping arrives and
        // `receive` is expected to fail (presumably because the connection is closed).
        client.disconnect(&cx.to_async()).await.unwrap();
        assert!(server.receive::<proto::Ping>().await.is_err());
    }

    /// A lost connection is re-established in the background: cached credentials are
    /// reused first, and the client re-authenticates only once they are rejected.
    #[gpui::test(iterations = 10)]
    async fn test_reconnection(cx: TestAppContext) {
        cx.foreground().forbid_parking();

        let user_id = 5;
        let mut client = Client::new();
        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
        let mut status = client.status();
        assert!(matches!(
            status.recv().await,
            Some(Status::Connected { .. })
        ));
        assert_eq!(server.auth_count(), 1);

        // Drop the connection while refusing new ones; the first retry must fail.
        server.forbid_connections();
        server.disconnect().await;
        while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}

        // Advancing the clock fires the back-off timer, and the next retry succeeds.
        server.allow_connections();
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting

        server.forbid_connections();
        server.disconnect().await;
        while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}

        // Clear cached credentials after authentication fails
        server.roll_access_token();
        server.allow_connections();
        // The first retry still uses the now-stale cached token; it is rejected, so no new
        // authentication has happened yet.
        cx.foreground().advance_clock(Duration::from_secs(10));
        assert_eq!(server.auth_count(), 1);
        // The next retry has no cached credentials (assuming the test keychain is empty) and
        // authenticates again.
        cx.foreground().advance_clock(Duration::from_secs(10));
        while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
        assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
    }

    /// Worktree URLs round-trip, tolerate surrounding whitespace, and reject other schemes.
    #[test]
    fn test_encode_and_decode_worktree_url() {
        let url = encode_worktree_url(5, "deadbeef");
        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
        assert_eq!(
            decode_worktree_url(&format!("\n {}\t", url)),
            Some((5, "deadbeef".to_string()))
        );
        assert_eq!(decode_worktree_url("not://the-right-format"), None);
    }
}