test.rs

  1use super::Client;
  2use super::*;
  3use crate::http::{HttpClient, Request, Response, ServerResponse};
  4use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
  5use gpui::{ModelHandle, TestAppContext};
  6use parking_lot::Mutex;
  7use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
  8use std::fmt;
  9use std::sync::atomic::Ordering::SeqCst;
 10use std::sync::{
 11    atomic::{AtomicBool, AtomicUsize},
 12    Arc,
 13};
 14
 15pub struct FakeServer {
 16    peer: Arc<Peer>,
 17    incoming: Mutex<Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>>,
 18    connection_id: Mutex<Option<ConnectionId>>,
 19    forbid_connections: AtomicBool,
 20    auth_count: AtomicUsize,
 21    access_token: AtomicUsize,
 22    user_id: u64,
 23}
 24
 25impl FakeServer {
 26    pub async fn for_client(
 27        client_user_id: u64,
 28        client: &mut Arc<Client>,
 29        cx: &TestAppContext,
 30    ) -> Arc<Self> {
 31        let server = Arc::new(Self {
 32            peer: Peer::new(),
 33            incoming: Default::default(),
 34            connection_id: Default::default(),
 35            forbid_connections: Default::default(),
 36            auth_count: Default::default(),
 37            access_token: Default::default(),
 38            user_id: client_user_id,
 39        });
 40
 41        Arc::get_mut(client)
 42            .unwrap()
 43            .override_authenticate({
 44                let server = server.clone();
 45                move |cx| {
 46                    server.auth_count.fetch_add(1, SeqCst);
 47                    let access_token = server.access_token.load(SeqCst).to_string();
 48                    cx.spawn(move |_| async move {
 49                        Ok(Credentials {
 50                            user_id: client_user_id,
 51                            access_token,
 52                        })
 53                    })
 54                }
 55            })
 56            .override_establish_connection({
 57                let server = server.clone();
 58                move |credentials, cx| {
 59                    let credentials = credentials.clone();
 60                    cx.spawn({
 61                        let server = server.clone();
 62                        move |cx| async move { server.establish_connection(&credentials, &cx).await }
 63                    })
 64                }
 65            });
 66
 67        client
 68            .authenticate_and_connect(&cx.to_async())
 69            .await
 70            .unwrap();
 71        server
 72    }
 73
 74    pub fn disconnect(&self) {
 75        self.peer.disconnect(self.connection_id());
 76        self.connection_id.lock().take();
 77        self.incoming.lock().take();
 78    }
 79
 80    async fn establish_connection(
 81        &self,
 82        credentials: &Credentials,
 83        cx: &AsyncAppContext,
 84    ) -> Result<Connection, EstablishConnectionError> {
 85        assert_eq!(credentials.user_id, self.user_id);
 86
 87        if self.forbid_connections.load(SeqCst) {
 88            Err(EstablishConnectionError::Other(anyhow!(
 89                "server is forbidding connections"
 90            )))?
 91        }
 92
 93        if credentials.access_token != self.access_token.load(SeqCst).to_string() {
 94            Err(EstablishConnectionError::Unauthorized)?
 95        }
 96
 97        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
 98        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
 99        cx.background().spawn(io).detach();
100        *self.incoming.lock() = Some(incoming);
101        *self.connection_id.lock() = Some(connection_id);
102        Ok(client_conn)
103    }
104
105    pub fn auth_count(&self) -> usize {
106        self.auth_count.load(SeqCst)
107    }
108
109    pub fn roll_access_token(&self) {
110        self.access_token.fetch_add(1, SeqCst);
111    }
112
113    pub fn forbid_connections(&self) {
114        self.forbid_connections.store(true, SeqCst);
115    }
116
117    pub fn allow_connections(&self) {
118        self.forbid_connections.store(false, SeqCst);
119    }
120
121    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
122        self.peer.send(self.connection_id(), message).unwrap();
123    }
124
125    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
126        let message = self
127            .incoming
128            .lock()
129            .as_mut()
130            .expect("not connected")
131            .next()
132            .await
133            .ok_or_else(|| anyhow!("other half hung up"))?;
134        let type_name = message.payload_type_name();
135        Ok(*message
136            .into_any()
137            .downcast::<TypedEnvelope<M>>()
138            .unwrap_or_else(|_| {
139                panic!(
140                    "fake server received unexpected message type: {:?}",
141                    type_name
142                );
143            }))
144    }
145
146    pub async fn respond<T: proto::RequestMessage>(
147        &self,
148        receipt: Receipt<T>,
149        response: T::Response,
150    ) {
151        self.peer.respond(receipt, response).unwrap()
152    }
153
154    fn connection_id(&self) -> ConnectionId {
155        self.connection_id.lock().expect("not connected")
156    }
157
158    pub async fn build_user_store(
159        &self,
160        client: Arc<Client>,
161        cx: &mut TestAppContext,
162    ) -> ModelHandle<UserStore> {
163        let http_client = FakeHttpClient::with_404_response();
164        let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
165        assert_eq!(
166            self.receive::<proto::GetUsers>()
167                .await
168                .unwrap()
169                .payload
170                .user_ids,
171            &[self.user_id]
172        );
173        user_store
174    }
175}
176
177pub struct FakeHttpClient {
178    handler:
179        Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
180}
181
182impl FakeHttpClient {
183    pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
184    where
185        Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
186        F: 'static + Send + Sync + Fn(Request) -> Fut,
187    {
188        Arc::new(Self {
189            handler: Box::new(move |req| Box::pin(handler(req))),
190        })
191    }
192
193    pub fn with_404_response() -> Arc<dyn HttpClient> {
194        Self::new(|_| async move { Ok(ServerResponse::new(404)) })
195    }
196}
197
198impl fmt::Debug for FakeHttpClient {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        f.debug_struct("FakeHttpClient").finish()
201    }
202}
203
204impl HttpClient for FakeHttpClient {
205    fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
206        let future = (self.handler)(req);
207        Box::pin(async move { future.await.map(Into::into) })
208    }
209}