1use crate::{
2 http::{self, HttpClient, Request, Response},
3 Client, Connection, Credentials, EstablishConnectionError, UserStore,
4};
5use anyhow::{anyhow, Result};
6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
7use gpui::{executor, ModelHandle, TestAppContext};
8use parking_lot::Mutex;
9use rpc::{
10 proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
11 ConnectionId, Peer, Receipt, TypedEnvelope,
12};
13use std::{fmt, rc::Rc, sync::Arc};
14
15pub struct FakeServer {
16 peer: Arc<Peer>,
17 state: Arc<Mutex<FakeServerState>>,
18 user_id: u64,
19 executor: Rc<executor::Foreground>,
20}
21
22#[derive(Default)]
23struct FakeServerState {
24 incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
25 connection_id: Option<ConnectionId>,
26 forbid_connections: bool,
27 auth_count: usize,
28 access_token: usize,
29}
30
31impl FakeServer {
32 pub async fn for_client(
33 client_user_id: u64,
34 client: &Arc<Client>,
35 cx: &TestAppContext,
36 ) -> Self {
37 let server = Self {
38 peer: Peer::new(),
39 state: Default::default(),
40 user_id: client_user_id,
41 executor: cx.foreground(),
42 };
43
44 client
45 .override_authenticate({
46 let state = Arc::downgrade(&server.state);
47 move |cx| {
48 let state = state.clone();
49 cx.spawn(move |_| async move {
50 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
51 let mut state = state.lock();
52 state.auth_count += 1;
53 let access_token = state.access_token.to_string();
54 Ok(Credentials {
55 user_id: client_user_id,
56 access_token,
57 })
58 })
59 }
60 })
61 .override_establish_connection({
62 let peer = Arc::downgrade(&server.peer);
63 let state = Arc::downgrade(&server.state);
64 move |credentials, cx| {
65 let peer = peer.clone();
66 let state = state.clone();
67 let credentials = credentials.clone();
68 cx.spawn(move |cx| async move {
69 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
70 let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
71 if state.lock().forbid_connections {
72 Err(EstablishConnectionError::Other(anyhow!(
73 "server is forbidding connections"
74 )))?
75 }
76
77 assert_eq!(credentials.user_id, client_user_id);
78
79 if credentials.access_token != state.lock().access_token.to_string() {
80 Err(EstablishConnectionError::Unauthorized)?
81 }
82
83 let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
84 let (connection_id, io, incoming) =
85 peer.add_test_connection(server_conn, cx.background());
86 cx.background().spawn(io).detach();
87 {
88 let mut state = state.lock();
89 state.connection_id = Some(connection_id);
90 state.incoming = Some(incoming);
91 }
92 peer.send(
93 connection_id,
94 proto::Hello {
95 peer_id: connection_id.0,
96 },
97 )
98 .unwrap();
99
100 Ok(client_conn)
101 })
102 }
103 });
104
105 client
106 .authenticate_and_connect(false, &cx.to_async())
107 .await
108 .unwrap();
109
110 server
111 }
112
113 pub fn disconnect(&self) {
114 if self.state.lock().connection_id.is_some() {
115 self.peer.disconnect(self.connection_id());
116 let mut state = self.state.lock();
117 state.connection_id.take();
118 state.incoming.take();
119 }
120 }
121
122 pub fn auth_count(&self) -> usize {
123 self.state.lock().auth_count
124 }
125
126 pub fn roll_access_token(&self) {
127 self.state.lock().access_token += 1;
128 }
129
130 pub fn forbid_connections(&self) {
131 self.state.lock().forbid_connections = true;
132 }
133
134 pub fn allow_connections(&self) {
135 self.state.lock().forbid_connections = false;
136 }
137
138 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
139 self.peer.send(self.connection_id(), message).unwrap();
140 }
141
142 #[allow(clippy::await_holding_lock)]
143 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
144 self.executor.start_waiting();
145
146 loop {
147 let message = self
148 .state
149 .lock()
150 .incoming
151 .as_mut()
152 .expect("not connected")
153 .next()
154 .await
155 .ok_or_else(|| anyhow!("other half hung up"))?;
156 self.executor.finish_waiting();
157 let type_name = message.payload_type_name();
158 let message = message.into_any();
159
160 if message.is::<TypedEnvelope<M>>() {
161 return Ok(*message.downcast().unwrap());
162 }
163
164 if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
165 self.respond(
166 message
167 .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
168 .unwrap()
169 .receipt(),
170 GetPrivateUserInfoResponse {
171 metrics_id: "the-metrics-id".into(),
172 staff: false,
173 },
174 )
175 .await;
176 continue;
177 }
178
179 panic!(
180 "fake server received unexpected message type: {:?}",
181 type_name
182 );
183 }
184 }
185
186 pub async fn respond<T: proto::RequestMessage>(
187 &self,
188 receipt: Receipt<T>,
189 response: T::Response,
190 ) {
191 self.peer.respond(receipt, response).unwrap()
192 }
193
194 fn connection_id(&self) -> ConnectionId {
195 self.state.lock().connection_id.expect("not connected")
196 }
197
198 pub async fn build_user_store(
199 &self,
200 client: Arc<Client>,
201 cx: &mut TestAppContext,
202 ) -> ModelHandle<UserStore> {
203 let http_client = FakeHttpClient::with_404_response();
204 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
205 assert_eq!(
206 self.receive::<proto::GetUsers>()
207 .await
208 .unwrap()
209 .payload
210 .user_ids,
211 &[self.user_id]
212 );
213 user_store
214 }
215}
216
217impl Drop for FakeServer {
218 fn drop(&mut self) {
219 self.disconnect();
220 }
221}
222
223pub struct FakeHttpClient {
224 handler: Box<
225 dyn 'static
226 + Send
227 + Sync
228 + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
229 >,
230}
231
232impl FakeHttpClient {
233 pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
234 where
235 Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
236 F: 'static + Send + Sync + Fn(Request) -> Fut,
237 {
238 Arc::new(Self {
239 handler: Box::new(move |req| Box::pin(handler(req))),
240 })
241 }
242
243 pub fn with_404_response() -> Arc<dyn HttpClient> {
244 Self::create(|_| async move {
245 Ok(isahc::Response::builder()
246 .status(404)
247 .body(Default::default())
248 .unwrap())
249 })
250 }
251}
252
253impl fmt::Debug for FakeHttpClient {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("FakeHttpClient").finish()
256 }
257}
258
259impl HttpClient for FakeHttpClient {
260 fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
261 let future = (self.handler)(req);
262 Box::pin(async move { future.await.map(Into::into) })
263 }
264}