1use super::*;
2use std::sync::atomic::Ordering::SeqCst;
3
4use super::Client;
5use gpui::TestAppContext;
6use parking_lot::Mutex;
7use postage::{mpsc, prelude::Stream};
8use std::sync::{
9 atomic::{AtomicBool, AtomicUsize},
10 Arc,
11};
12use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
13
14pub struct FakeServer {
15 peer: Arc<Peer>,
16 incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
17 connection_id: Mutex<Option<ConnectionId>>,
18 forbid_connections: AtomicBool,
19 auth_count: AtomicUsize,
20 access_token: AtomicUsize,
21 user_id: u64,
22}
23
24impl FakeServer {
25 pub async fn for_client(
26 client_user_id: u64,
27 client: &mut Arc<Client>,
28 cx: &TestAppContext,
29 ) -> Arc<Self> {
30 let server = Arc::new(Self {
31 peer: Peer::new(),
32 incoming: Default::default(),
33 connection_id: Default::default(),
34 forbid_connections: Default::default(),
35 auth_count: Default::default(),
36 access_token: Default::default(),
37 user_id: client_user_id,
38 });
39
40 Arc::get_mut(client)
41 .unwrap()
42 .override_authenticate({
43 let server = server.clone();
44 move |cx| {
45 server.auth_count.fetch_add(1, SeqCst);
46 let access_token = server.access_token.load(SeqCst).to_string();
47 cx.spawn(move |_| async move {
48 Ok(Credentials {
49 user_id: client_user_id,
50 access_token,
51 })
52 })
53 }
54 })
55 .override_establish_connection({
56 let server = server.clone();
57 move |credentials, cx| {
58 let credentials = credentials.clone();
59 cx.spawn({
60 let server = server.clone();
61 move |cx| async move { server.establish_connection(&credentials, &cx).await }
62 })
63 }
64 });
65
66 client
67 .authenticate_and_connect(&cx.to_async())
68 .await
69 .unwrap();
70 server
71 }
72
73 pub async fn disconnect(&self) {
74 self.peer.disconnect(self.connection_id()).await;
75 self.connection_id.lock().take();
76 self.incoming.lock().take();
77 }
78
79 async fn establish_connection(
80 &self,
81 credentials: &Credentials,
82 cx: &AsyncAppContext,
83 ) -> Result<Connection, EstablishConnectionError> {
84 assert_eq!(credentials.user_id, self.user_id);
85
86 if self.forbid_connections.load(SeqCst) {
87 Err(EstablishConnectionError::Other(anyhow!(
88 "server is forbidding connections"
89 )))?
90 }
91
92 if credentials.access_token != self.access_token.load(SeqCst).to_string() {
93 Err(EstablishConnectionError::Unauthorized)?
94 }
95
96 let (client_conn, server_conn, _) = Connection::in_memory();
97 let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
98 cx.background().spawn(io).detach();
99 *self.incoming.lock() = Some(incoming);
100 *self.connection_id.lock() = Some(connection_id);
101 Ok(client_conn)
102 }
103
104 pub fn auth_count(&self) -> usize {
105 self.auth_count.load(SeqCst)
106 }
107
108 pub fn roll_access_token(&self) {
109 self.access_token.fetch_add(1, SeqCst);
110 }
111
112 pub fn forbid_connections(&self) {
113 self.forbid_connections.store(true, SeqCst);
114 }
115
116 pub fn allow_connections(&self) {
117 self.forbid_connections.store(false, SeqCst);
118 }
119
120 pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
121 self.peer.send(self.connection_id(), message).await.unwrap();
122 }
123
124 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
125 let message = self
126 .incoming
127 .lock()
128 .as_mut()
129 .expect("not connected")
130 .recv()
131 .await
132 .ok_or_else(|| anyhow!("other half hung up"))?;
133 let type_name = message.payload_type_name();
134 Ok(*message
135 .into_any()
136 .downcast::<TypedEnvelope<M>>()
137 .unwrap_or_else(|_| {
138 panic!(
139 "fake server received unexpected message type: {:?}",
140 type_name
141 );
142 }))
143 }
144
145 pub async fn respond<T: proto::RequestMessage>(
146 &self,
147 receipt: Receipt<T>,
148 response: T::Response,
149 ) {
150 self.peer.respond(receipt, response).await.unwrap()
151 }
152
153 fn connection_id(&self) -> ConnectionId {
154 self.connection_id.lock().expect("not connected")
155 }
156}