1use crate::{
2 http::{self, HttpClient, Request, Response},
3 Client, Connection, Credentials, EstablishConnectionError, UserStore,
4};
5use anyhow::{anyhow, Result};
6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
7use gpui::{executor, ModelHandle, TestAppContext};
8use parking_lot::Mutex;
9use rpc::{
10 proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
11 ConnectionId, Peer, Receipt, TypedEnvelope,
12};
13use std::{fmt, rc::Rc, sync::Arc};
14
/// In-process stand-in for the RPC server, used by client tests.
///
/// Construct it with `FakeServer::for_client`, which installs authentication
/// and connection hooks on a `Client` that route to this server.
pub struct FakeServer {
    /// RPC peer to which the server side of each in-memory connection is added.
    peer: Arc<Peer>,
    /// Connection and authentication state; the hooks installed on the client
    /// hold weak references to it.
    state: Arc<Mutex<FakeServerState>>,
    /// Id of the user the client is expected to authenticate as.
    user_id: u64,
    /// The test's foreground executor, used to mark `receive` as waiting.
    executor: Rc<executor::Foreground>,
}
21
/// Mutable state of a `FakeServer`, also reached by the hooks it installs on the client.
#[derive(Default)]
struct FakeServerState {
    /// Messages received from the connected client; `None` while disconnected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current client connection; `None` while disconnected.
    connection_id: Option<ConnectionId>,
    /// When `true`, the connection hook rejects new connection attempts.
    forbid_connections: bool,
    /// Number of times the authentication hook has produced credentials.
    auth_count: usize,
    /// Current access token; only its decimal string form is accepted when connecting.
    access_token: usize,
}
30
31impl FakeServer {
32 pub async fn for_client(
33 client_user_id: u64,
34 client: &Arc<Client>,
35 cx: &TestAppContext,
36 ) -> Self {
37 let server = Self {
38 peer: Peer::new(),
39 state: Default::default(),
40 user_id: client_user_id,
41 executor: cx.foreground(),
42 };
43
44 client
45 .override_authenticate({
46 let state = Arc::downgrade(&server.state);
47 move |cx| {
48 let state = state.clone();
49 cx.spawn(move |_| async move {
50 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
51 let mut state = state.lock();
52 state.auth_count += 1;
53 let access_token = state.access_token.to_string();
54 Ok(Credentials {
55 user_id: client_user_id,
56 access_token,
57 })
58 })
59 }
60 })
61 .override_establish_connection({
62 let peer = Arc::downgrade(&server.peer);
63 let state = Arc::downgrade(&server.state);
64 move |credentials, cx| {
65 let peer = peer.clone();
66 let state = state.clone();
67 let credentials = credentials.clone();
68 cx.spawn(move |cx| async move {
69 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
70 let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
71 if state.lock().forbid_connections {
72 Err(EstablishConnectionError::Other(anyhow!(
73 "server is forbidding connections"
74 )))?
75 }
76
77 assert_eq!(credentials.user_id, client_user_id);
78
79 if credentials.access_token != state.lock().access_token.to_string() {
80 Err(EstablishConnectionError::Unauthorized)?
81 }
82
83 let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
84 let (connection_id, io, incoming) =
85 peer.add_test_connection(server_conn, cx.background()).await;
86 cx.background().spawn(io).detach();
87 let mut state = state.lock();
88 state.connection_id = Some(connection_id);
89 state.incoming = Some(incoming);
90 Ok(client_conn)
91 })
92 }
93 });
94
95 client
96 .authenticate_and_connect(false, &cx.to_async())
97 .await
98 .unwrap();
99
100 server
101 }
102
103 pub fn disconnect(&self) {
104 self.peer.disconnect(self.connection_id());
105 let mut state = self.state.lock();
106 state.connection_id.take();
107 state.incoming.take();
108 }
109
110 pub fn auth_count(&self) -> usize {
111 self.state.lock().auth_count
112 }
113
114 pub fn roll_access_token(&self) {
115 self.state.lock().access_token += 1;
116 }
117
118 pub fn forbid_connections(&self) {
119 self.state.lock().forbid_connections = true;
120 }
121
122 pub fn allow_connections(&self) {
123 self.state.lock().forbid_connections = false;
124 }
125
126 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
127 self.peer.send(self.connection_id(), message).unwrap();
128 }
129
130 #[allow(clippy::await_holding_lock)]
131 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
132 self.executor.start_waiting();
133
134 loop {
135 let message = self
136 .state
137 .lock()
138 .incoming
139 .as_mut()
140 .expect("not connected")
141 .next()
142 .await
143 .ok_or_else(|| anyhow!("other half hung up"))?;
144 self.executor.finish_waiting();
145 let type_name = message.payload_type_name();
146 let message = message.into_any();
147
148 if message.is::<TypedEnvelope<M>>() {
149 return Ok(*message.downcast().unwrap());
150 }
151
152 if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
153 self.respond(
154 message
155 .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
156 .unwrap()
157 .receipt(),
158 GetPrivateUserInfoResponse {
159 metrics_id: "the-metrics-id".into(),
160 staff: false,
161 },
162 )
163 .await;
164 continue;
165 }
166
167 panic!(
168 "fake server received unexpected message type: {:?}",
169 type_name
170 );
171 }
172 }
173
174 pub async fn respond<T: proto::RequestMessage>(
175 &self,
176 receipt: Receipt<T>,
177 response: T::Response,
178 ) {
179 self.peer.respond(receipt, response).unwrap()
180 }
181
182 fn connection_id(&self) -> ConnectionId {
183 self.state.lock().connection_id.expect("not connected")
184 }
185
186 pub async fn build_user_store(
187 &self,
188 client: Arc<Client>,
189 cx: &mut TestAppContext,
190 ) -> ModelHandle<UserStore> {
191 let http_client = FakeHttpClient::with_404_response();
192 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
193 assert_eq!(
194 self.receive::<proto::GetUsers>()
195 .await
196 .unwrap()
197 .payload
198 .user_ids,
199 &[self.user_id]
200 );
201 user_store
202 }
203}
204
205impl Drop for FakeServer {
206 fn drop(&mut self) {
207 self.disconnect();
208 }
209}
210
/// `HttpClient` test double whose responses come from a caller-supplied handler.
pub struct FakeHttpClient {
    /// Called once per request; returns the boxed future that yields the response.
    handler: Box<
        dyn 'static
            + Send
            + Sync
            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
    >,
}
219
220impl FakeHttpClient {
221 pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
222 where
223 Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
224 F: 'static + Send + Sync + Fn(Request) -> Fut,
225 {
226 Arc::new(Self {
227 handler: Box::new(move |req| Box::pin(handler(req))),
228 })
229 }
230
231 pub fn with_404_response() -> Arc<dyn HttpClient> {
232 Self::create(|_| async move {
233 Ok(isahc::Response::builder()
234 .status(404)
235 .body(Default::default())
236 .unwrap())
237 })
238 }
239}
240
241impl fmt::Debug for FakeHttpClient {
242 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
243 f.debug_struct("FakeHttpClient").finish()
244 }
245}
246
247impl HttpClient for FakeHttpClient {
248 fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
249 let future = (self.handler)(req);
250 Box::pin(async move { future.await.map(Into::into) })
251 }
252}