@@ -1,25 +1,28 @@
-use super::Client;
-use super::*;
-use crate::http::{HttpClient, Request, Response, ServerResponse};
+use crate::{
+ http::{HttpClient, Request, Response, ServerResponse},
+ Client, Connection, Credentials, EstablishConnectionError, UserStore,
+};
+use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
-use gpui::{ModelHandle, TestAppContext};
+use gpui::{executor, ModelHandle, TestAppContext};
use parking_lot::Mutex;
use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
-use std::fmt;
-use std::sync::atomic::Ordering::SeqCst;
-use std::sync::{
- atomic::{AtomicBool, AtomicUsize},
- Arc,
-};
+use std::{fmt, rc::Rc, sync::Arc};
pub struct FakeServer {
peer: Arc<Peer>,
- incoming: Mutex<Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>>,
- connection_id: Mutex<Option<ConnectionId>>,
- forbid_connections: AtomicBool,
- auth_count: AtomicUsize,
- access_token: AtomicUsize,
+ state: Arc<Mutex<FakeServerState>>,
user_id: u64,
+ executor: Rc<executor::Foreground>,
+}
+
+#[derive(Default)]
+struct FakeServerState {
+ incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
+ connection_id: Option<ConnectionId>,
+ forbid_connections: bool,
+ auth_count: usize,
+ access_token: usize,
}
impl FakeServer {
@@ -27,24 +30,22 @@ impl FakeServer {
client_user_id: u64,
client: &mut Arc<Client>,
cx: &TestAppContext,
- ) -> Arc<Self> {
- let server = Arc::new(Self {
+ ) -> Self {
+ let server = Self {
peer: Peer::new(),
- incoming: Default::default(),
- connection_id: Default::default(),
- forbid_connections: Default::default(),
- auth_count: Default::default(),
- access_token: Default::default(),
+ state: Default::default(),
user_id: client_user_id,
- });
+ executor: cx.foreground(),
+ };
Arc::get_mut(client)
.unwrap()
.override_authenticate({
- let server = server.clone();
+ let state = server.state.clone();
move |cx| {
- server.auth_count.fetch_add(1, SeqCst);
- let access_token = server.access_token.load(SeqCst).to_string();
+ let mut state = state.lock();
+ state.auth_count += 1;
+ let access_token = state.access_token.to_string();
cx.spawn(move |_| async move {
Ok(Credentials {
user_id: client_user_id,
@@ -54,12 +55,32 @@ impl FakeServer {
}
})
.override_establish_connection({
- let server = server.clone();
+ let peer = server.peer.clone();
+ let state = server.state.clone();
move |credentials, cx| {
+ let peer = peer.clone();
+ let state = state.clone();
let credentials = credentials.clone();
- cx.spawn({
- let server = server.clone();
- move |cx| async move { server.establish_connection(&credentials, &cx).await }
+ cx.spawn(move |cx| async move {
+ assert_eq!(credentials.user_id, client_user_id);
+
+ if state.lock().forbid_connections {
+ Err(EstablishConnectionError::Other(anyhow!(
+ "server is forbidding connections"
+ )))?
+ }
+
+ if credentials.access_token != state.lock().access_token.to_string() {
+ Err(EstablishConnectionError::Unauthorized)?
+ }
+
+ let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
+ let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+ cx.background().spawn(io).detach();
+ let mut state = state.lock();
+ state.connection_id = Some(connection_id);
+ state.incoming = Some(incoming);
+ Ok(client_conn)
})
}
});
@@ -73,49 +94,25 @@ impl FakeServer {
pub fn disconnect(&self) {
self.peer.disconnect(self.connection_id());
- self.connection_id.lock().take();
- self.incoming.lock().take();
- }
-
- async fn establish_connection(
- &self,
- credentials: &Credentials,
- cx: &AsyncAppContext,
- ) -> Result<Connection, EstablishConnectionError> {
- assert_eq!(credentials.user_id, self.user_id);
-
- if self.forbid_connections.load(SeqCst) {
- Err(EstablishConnectionError::Other(anyhow!(
- "server is forbidding connections"
- )))?
- }
-
- if credentials.access_token != self.access_token.load(SeqCst).to_string() {
- Err(EstablishConnectionError::Unauthorized)?
- }
-
- let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
- let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
- cx.background().spawn(io).detach();
- *self.incoming.lock() = Some(incoming);
- *self.connection_id.lock() = Some(connection_id);
- Ok(client_conn)
+ let mut state = self.state.lock();
+ state.connection_id.take();
+ state.incoming.take();
}
pub fn auth_count(&self) -> usize {
- self.auth_count.load(SeqCst)
+ self.state.lock().auth_count
}
pub fn roll_access_token(&self) {
- self.access_token.fetch_add(1, SeqCst);
+ self.state.lock().access_token += 1;
}
pub fn forbid_connections(&self) {
- self.forbid_connections.store(true, SeqCst);
+ self.state.lock().forbid_connections = true;
}
pub fn allow_connections(&self) {
- self.forbid_connections.store(false, SeqCst);
+ self.state.lock().forbid_connections = false;
}
pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
@@ -123,14 +120,17 @@ impl FakeServer {
}
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
+ self.executor.start_waiting();
let message = self
- .incoming
+ .state
.lock()
+ .incoming
.as_mut()
.expect("not connected")
.next()
.await
.ok_or_else(|| anyhow!("other half hung up"))?;
+ self.executor.finish_waiting();
let type_name = message.payload_type_name();
Ok(*message
.into_any()
@@ -152,7 +152,7 @@ impl FakeServer {
}
fn connection_id(&self) -> ConnectionId {
- self.connection_id.lock().expect("not connected")
+ self.state.lock().connection_id.expect("not connected")
}
pub async fn build_user_store(