1use crate::{
2 http::{self, HttpClient, Request, Response},
3 Client, Connection, Credentials, EstablishConnectionError, UserStore,
4};
5use anyhow::{anyhow, Result};
6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
7use gpui::{executor, ModelHandle, TestAppContext};
8use parking_lot::Mutex;
9use rpc::{
10 proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
11 ConnectionId, Peer, Receipt, TypedEnvelope,
12};
13use std::{fmt, rc::Rc, sync::Arc};
14
15pub struct FakeServer {
16 peer: Arc<Peer>,
17 state: Arc<Mutex<FakeServerState>>,
18 user_id: u64,
19 executor: Rc<executor::Foreground>,
20}
21
22#[derive(Default)]
23struct FakeServerState {
24 incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
25 connection_id: Option<ConnectionId>,
26 forbid_connections: bool,
27 auth_count: usize,
28 access_token: usize,
29}
30
31impl FakeServer {
32 pub async fn for_client(
33 client_user_id: u64,
34 client: &Arc<Client>,
35 cx: &TestAppContext,
36 ) -> Self {
37 let server = Self {
38 peer: Peer::new(),
39 state: Default::default(),
40 user_id: client_user_id,
41 executor: cx.foreground(),
42 };
43
44 client
45 .override_authenticate({
46 let state = Arc::downgrade(&server.state);
47 move |cx| {
48 let state = state.clone();
49 cx.spawn(move |_| async move {
50 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
51 let mut state = state.lock();
52 state.auth_count += 1;
53 let access_token = state.access_token.to_string();
54 Ok(Credentials {
55 user_id: client_user_id,
56 access_token,
57 })
58 })
59 }
60 })
61 .override_establish_connection({
62 let peer = Arc::downgrade(&server.peer);
63 let state = Arc::downgrade(&server.state);
64 move |credentials, cx| {
65 let peer = peer.clone();
66 let state = state.clone();
67 let credentials = credentials.clone();
68 cx.spawn(move |cx| async move {
69 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
70 let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
71 if state.lock().forbid_connections {
72 Err(EstablishConnectionError::Other(anyhow!(
73 "server is forbidding connections"
74 )))?
75 }
76
77 assert_eq!(credentials.user_id, client_user_id);
78
79 if credentials.access_token != state.lock().access_token.to_string() {
80 Err(EstablishConnectionError::Unauthorized)?
81 }
82
83 let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
84 let (connection_id, io, incoming) =
85 peer.add_test_connection(server_conn, cx.background()).await;
86 cx.background().spawn(io).detach();
87 let mut state = state.lock();
88 state.connection_id = Some(connection_id);
89 state.incoming = Some(incoming);
90 Ok(client_conn)
91 })
92 }
93 });
94
95 client
96 .authenticate_and_connect(false, &cx.to_async())
97 .await
98 .unwrap();
99
100 server
101 }
102
103 pub fn disconnect(&self) {
104 self.peer.disconnect(self.connection_id());
105 let mut state = self.state.lock();
106 state.connection_id.take();
107 state.incoming.take();
108 }
109
110 pub fn auth_count(&self) -> usize {
111 self.state.lock().auth_count
112 }
113
114 pub fn roll_access_token(&self) {
115 self.state.lock().access_token += 1;
116 }
117
118 pub fn forbid_connections(&self) {
119 self.state.lock().forbid_connections = true;
120 }
121
122 pub fn allow_connections(&self) {
123 self.state.lock().forbid_connections = false;
124 }
125
126 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
127 self.peer.send(self.connection_id(), message).unwrap();
128 }
129
130 #[allow(clippy::await_holding_lock)]
131 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
132 self.executor.start_waiting();
133
134 loop {
135 let message = self
136 .state
137 .lock()
138 .incoming
139 .as_mut()
140 .expect("not connected")
141 .next()
142 .await
143 .ok_or_else(|| anyhow!("other half hung up"))?;
144 self.executor.finish_waiting();
145 let type_name = message.payload_type_name();
146 let message = message.into_any();
147
148 if message.is::<TypedEnvelope<M>>() {
149 return Ok(*message.downcast().unwrap());
150 }
151
152 if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
153 self.respond(
154 message
155 .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
156 .unwrap()
157 .receipt(),
158 GetPrivateUserInfoResponse {
159 metrics_id: "the-metrics-id".into(),
160 },
161 )
162 .await;
163 continue;
164 }
165
166 panic!(
167 "fake server received unexpected message type: {:?}",
168 type_name
169 );
170 }
171 }
172
173 pub async fn respond<T: proto::RequestMessage>(
174 &self,
175 receipt: Receipt<T>,
176 response: T::Response,
177 ) {
178 self.peer.respond(receipt, response).unwrap()
179 }
180
181 fn connection_id(&self) -> ConnectionId {
182 self.state.lock().connection_id.expect("not connected")
183 }
184
185 pub async fn build_user_store(
186 &self,
187 client: Arc<Client>,
188 cx: &mut TestAppContext,
189 ) -> ModelHandle<UserStore> {
190 let http_client = FakeHttpClient::with_404_response();
191 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
192 assert_eq!(
193 self.receive::<proto::GetUsers>()
194 .await
195 .unwrap()
196 .payload
197 .user_ids,
198 &[self.user_id]
199 );
200 user_store
201 }
202}
203
204impl Drop for FakeServer {
205 fn drop(&mut self) {
206 self.disconnect();
207 }
208}
209
210pub struct FakeHttpClient {
211 handler: Box<
212 dyn 'static
213 + Send
214 + Sync
215 + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
216 >,
217}
218
219impl FakeHttpClient {
220 pub fn create<Fut, F>(handler: F) -> Arc<dyn HttpClient>
221 where
222 Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
223 F: 'static + Send + Sync + Fn(Request) -> Fut,
224 {
225 Arc::new(Self {
226 handler: Box::new(move |req| Box::pin(handler(req))),
227 })
228 }
229
230 pub fn with_404_response() -> Arc<dyn HttpClient> {
231 Self::create(|_| async move {
232 Ok(isahc::Response::builder()
233 .status(404)
234 .body(Default::default())
235 .unwrap())
236 })
237 }
238}
239
240impl fmt::Debug for FakeHttpClient {
241 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
242 f.debug_struct("FakeHttpClient").finish()
243 }
244}
245
246impl HttpClient for FakeHttpClient {
247 fn send(&self, req: Request) -> BoxFuture<Result<Response, crate::http::Error>> {
248 let future = (self.handler)(req);
249 Box::pin(async move { future.await.map(Into::into) })
250 }
251}