1use crate::{
2 http::{self, HttpClient, Request, Response},
3 Client, Connection, Credentials, EstablishConnectionError, UserStore,
4};
5use anyhow::{anyhow, Result};
6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
7use gpui::{executor, ModelHandle, TestAppContext};
8use parking_lot::Mutex;
9use rpc::{
10 proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
11 ConnectionId, Peer, Receipt, TypedEnvelope,
12};
13use std::{fmt, rc::Rc, sync::Arc};
14
/// Fake RPC server that talks to a [`Client`] over an in-memory connection.
///
/// Construct it with [`FakeServer::for_client`], which installs authentication and
/// connection hooks on the client that point at this server. Dropping the server
/// disconnects the client (see the `Drop` impl).
pub struct FakeServer {
    /// Server-side RPC peer; the client's in-memory connection is registered here.
    peer: Arc<Peer>,
    /// Connection/auth state. The hooks installed on the client hold only weak
    /// references to it.
    state: Arc<Mutex<FakeServerState>>,
    /// User id the client is expected to authenticate as.
    user_id: u64,
    /// Test foreground executor; `receive` marks on it when it is blocked waiting.
    executor: Rc<executor::Foreground>,
}
21
/// Mutable state of a [`FakeServer`], kept behind its mutex.
#[derive(Default)]
struct FakeServerState {
    /// Messages arriving from the client. `None` until a connection is established
    /// and again after a disconnect.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Server-side id of the current client connection, if any.
    connection_id: Option<ConnectionId>,
    /// While `true`, connection attempts fail with `EstablishConnectionError::Other`.
    forbid_connections: bool,
    /// Number of times the client's authenticate hook has run.
    auth_count: usize,
    /// Current access token. It is compared, as a string, against the token in the
    /// client's credentials. Bumping it makes previously issued credentials fail
    /// with `Unauthorized`.
    access_token: usize,
}
30
31impl FakeServer {
32 pub async fn for_client(
33 client_user_id: u64,
34 client: &Arc<Client>,
35 cx: &TestAppContext,
36 ) -> Self {
37 let server = Self {
38 peer: Peer::new(),
39 state: Default::default(),
40 user_id: client_user_id,
41 executor: cx.foreground(),
42 };
43
44 client
45 .override_authenticate({
46 let state = Arc::downgrade(&server.state);
47 move |cx| {
48 let state = state.clone();
49 cx.spawn(move |_| async move {
50 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
51 let mut state = state.lock();
52 state.auth_count += 1;
53 let access_token = state.access_token.to_string();
54 Ok(Credentials {
55 user_id: client_user_id,
56 access_token,
57 })
58 })
59 }
60 })
61 .override_establish_connection({
62 let peer = Arc::downgrade(&server.peer);
63 let state = Arc::downgrade(&server.state);
64 move |credentials, cx| {
65 let peer = peer.clone();
66 let state = state.clone();
67 let credentials = credentials.clone();
68 cx.spawn(move |cx| async move {
69 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
70 let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
71 if state.lock().forbid_connections {
72 Err(EstablishConnectionError::Other(anyhow!(
73 "server is forbidding connections"
74 )))?
75 }
76
77 assert_eq!(credentials.user_id, client_user_id);
78
79 if credentials.access_token != state.lock().access_token.to_string() {
80 Err(EstablishConnectionError::Unauthorized)?
81 }
82
83 let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
84 let (connection_id, io, incoming) =
85 peer.add_test_connection(server_conn, cx.background());
86 cx.background().spawn(io).detach();
87 let mut state = state.lock();
88 state.connection_id = Some(connection_id);
89 state.incoming = Some(incoming);
90 Ok(client_conn)
91 })
92 }
93 });
94
95 client
96 .authenticate_and_connect(false, &cx.to_async())
97 .await
98 .unwrap();
99
100 server
101 }
102
103 pub fn disconnect(&self) {
104 if self.state.lock().connection_id.is_some() {
105 self.peer.disconnect(self.connection_id());
106 let mut state = self.state.lock();
107 state.connection_id.take();
108 state.incoming.take();
109 }
110 }
111
112 pub fn auth_count(&self) -> usize {
113 self.state.lock().auth_count
114 }
115
116 pub fn roll_access_token(&self) {
117 self.state.lock().access_token += 1;
118 }
119
120 pub fn forbid_connections(&self) {
121 self.state.lock().forbid_connections = true;
122 }
123
124 pub fn allow_connections(&self) {
125 self.state.lock().forbid_connections = false;
126 }
127
128 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
129 self.peer.send(self.connection_id(), message).unwrap();
130 }
131
132 #[allow(clippy::await_holding_lock)]
133 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
134 self.executor.start_waiting();
135
136 loop {
137 let message = self
138 .state
139 .lock()
140 .incoming
141 .as_mut()
142 .expect("not connected")
143 .next()
144 .await
145 .ok_or_else(|| anyhow!("other half hung up"))?;
146 self.executor.finish_waiting();
147 let type_name = message.payload_type_name();
148 let message = message.into_any();
149
150 if message.is::<TypedEnvelope<M>>() {
151 return Ok(*message.downcast().unwrap());
152 }
153
154 if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
155 self.respond(
156 message
157 .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
158 .unwrap()
159 .receipt(),
160 GetPrivateUserInfoResponse {
161 metrics_id: "the-metrics-id".into(),
162 staff: false,
163 },
164 )
165 .await;
166 continue;
167 }
168
169 panic!(
170 "fake server received unexpected message type: {:?}",
171 type_name
172 );
173 }
174 }
175
176 pub async fn respond<T: proto::RequestMessage>(
177 &self,
178 receipt: Receipt<T>,
179 response: T::Response,
180 ) {
181 self.peer.respond(receipt, response).unwrap()
182 }
183
184 fn connection_id(&self) -> ConnectionId {
185 self.state.lock().connection_id.expect("not connected")
186 }
187
188 pub async fn build_user_store(
189 &self,
190 client: Arc<Client>,
191 cx: &mut TestAppContext,
192 ) -> ModelHandle<UserStore> {
193 let http_client = FakeHttpClient::with_404_response();
194 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
195 assert_eq!(
196 self.receive::<proto::GetUsers>()
197 .await
198 .unwrap()
199 .payload
200 .user_ids,
201 &[self.user_id]
202 );
203 user_store
204 }
205}
206
207impl Drop for FakeServer {
208 fn drop(&mut self) {
209 self.disconnect();
210 }
211}
212
213pub struct FakeHttpClient {
214 handler: Box<
215 dyn 'static
216 + Send
217 + Sync
218 + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
219 >,
220}
221
222impl FakeHttpClient {
223 pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
224 where
225 Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
226 F: 'static + Send + Sync + Fn(Request) -> Fut,
227 {
228 Arc::new(Self {
229 handler: Box::new(move |req| Box::pin(handler(req))),
230 })
231 }
232
233 pub fn with_404_response() -> Arc<dyn HttpClient> {
234 Self::create(|_| async move {
235 Ok(isahc::Response::builder()
236 .status(404)
237 .body(Default::default())
238 .unwrap())
239 })
240 }
241}
242
243impl fmt::Debug for FakeHttpClient {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 f.debug_struct("FakeHttpClient").finish()
246 }
247}
248
249impl HttpClient for FakeHttpClient {
250 fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
251 let future = (self.handler)(req);
252 Box::pin(async move { future.await.map(Into::into) })
253 }
254}