1use crate::{
2 http::{self, HttpClient, Request, Response},
3 Client, Connection, Credentials, EstablishConnectionError, UserStore,
4};
5use anyhow::{anyhow, Result};
6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
7use gpui::{executor, ModelHandle, TestAppContext};
8use parking_lot::Mutex;
9use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
10use std::{fmt, rc::Rc, sync::Arc};
11
/// In-memory stand-in for the server, used by tests.
///
/// Built with [`FakeServer::for_client`], which overrides the client's
/// authentication and connection hooks so that the client talks to this
/// value instead of a real server.
pub struct FakeServer {
    /// RPC peer used to send messages, respond to requests and drop the
    /// connection.
    peer: Arc<Peer>,
    /// Mutable server state; the client's override closures hold weak
    /// references to it.
    state: Arc<Mutex<FakeServerState>>,
    /// Id of the user the connected client authenticates as.
    user_id: u64,
    /// Foreground executor, used by `receive` to bracket its wait with
    /// `start_waiting` / `finish_waiting`.
    executor: Rc<executor::Foreground>,
}
18
/// Mutable state of a [`FakeServer`], guarded by a mutex.
#[derive(Default)]
struct FakeServerState {
    /// Stream of messages arriving from the client; set when a connection is
    /// established and cleared on disconnect.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current connection; set when a connection is established
    /// and cleared on disconnect.
    connection_id: Option<ConnectionId>,
    /// While `true`, attempts to establish a connection fail.
    forbid_connections: bool,
    /// Number of times the authentication override has run.
    auth_count: usize,
    /// Current access token; its decimal string form is what credentials
    /// carry and what connection attempts are checked against.
    access_token: usize,
}
27
28impl FakeServer {
29 pub async fn for_client(
30 client_user_id: u64,
31 client: &Arc<Client>,
32 cx: &TestAppContext,
33 ) -> Self {
34 let server = Self {
35 peer: Peer::new(),
36 state: Default::default(),
37 user_id: client_user_id,
38 executor: cx.foreground(),
39 };
40
41 client
42 .override_authenticate({
43 let state = Arc::downgrade(&server.state);
44 move |cx| {
45 let state = state.clone();
46 cx.spawn(move |_| async move {
47 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
48 let mut state = state.lock();
49 state.auth_count += 1;
50 let access_token = state.access_token.to_string();
51 Ok(Credentials {
52 user_id: client_user_id,
53 access_token,
54 })
55 })
56 }
57 })
58 .override_establish_connection({
59 let peer = Arc::downgrade(&server.peer).clone();
60 let state = Arc::downgrade(&server.state);
61 move |credentials, cx| {
62 let peer = peer.clone();
63 let state = state.clone();
64 let credentials = credentials.clone();
65 cx.spawn(move |cx| async move {
66 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
67 let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
68 if state.lock().forbid_connections {
69 Err(EstablishConnectionError::Other(anyhow!(
70 "server is forbidding connections"
71 )))?
72 }
73
74 assert_eq!(credentials.user_id, client_user_id);
75
76 if credentials.access_token != state.lock().access_token.to_string() {
77 Err(EstablishConnectionError::Unauthorized)?
78 }
79
80 let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
81 let (connection_id, io, incoming) =
82 peer.add_test_connection(server_conn, cx.background()).await;
83 cx.background().spawn(io).detach();
84 let mut state = state.lock();
85 state.connection_id = Some(connection_id);
86 state.incoming = Some(incoming);
87 Ok(client_conn)
88 })
89 }
90 });
91
92 client
93 .authenticate_and_connect(false, &cx.to_async())
94 .await
95 .unwrap();
96 server
97 }
98
99 pub fn disconnect(&self) {
100 self.peer.disconnect(self.connection_id());
101 let mut state = self.state.lock();
102 state.connection_id.take();
103 state.incoming.take();
104 }
105
106 pub fn auth_count(&self) -> usize {
107 self.state.lock().auth_count
108 }
109
110 pub fn roll_access_token(&self) {
111 self.state.lock().access_token += 1;
112 }
113
114 pub fn forbid_connections(&self) {
115 self.state.lock().forbid_connections = true;
116 }
117
118 pub fn allow_connections(&self) {
119 self.state.lock().forbid_connections = false;
120 }
121
122 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
123 self.peer.send(self.connection_id(), message).unwrap();
124 }
125
126 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
127 self.executor.start_waiting();
128 let message = self
129 .state
130 .lock()
131 .incoming
132 .as_mut()
133 .expect("not connected")
134 .next()
135 .await
136 .ok_or_else(|| anyhow!("other half hung up"))?;
137 self.executor.finish_waiting();
138 let type_name = message.payload_type_name();
139 Ok(*message
140 .into_any()
141 .downcast::<TypedEnvelope<M>>()
142 .unwrap_or_else(|_| {
143 panic!(
144 "fake server received unexpected message type: {:?}",
145 type_name
146 );
147 }))
148 }
149
150 pub async fn respond<T: proto::RequestMessage>(
151 &self,
152 receipt: Receipt<T>,
153 response: T::Response,
154 ) {
155 self.peer.respond(receipt, response).unwrap()
156 }
157
158 fn connection_id(&self) -> ConnectionId {
159 self.state.lock().connection_id.expect("not connected")
160 }
161
162 pub async fn build_user_store(
163 &self,
164 client: Arc<Client>,
165 cx: &mut TestAppContext,
166 ) -> ModelHandle<UserStore> {
167 let http_client = FakeHttpClient::with_404_response();
168 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
169 assert_eq!(
170 self.receive::<proto::GetUsers>()
171 .await
172 .unwrap()
173 .payload
174 .user_ids,
175 &[self.user_id]
176 );
177 user_store
178 }
179}
180
181impl Drop for FakeServer {
182 fn drop(&mut self) {
183 self.disconnect();
184 }
185}
186
/// `HttpClient` implementation for tests that answers every request by
/// calling a caller-supplied handler.
pub struct FakeHttpClient {
    /// Callback invoked for each request; returns a boxed future resolving
    /// to the response or an HTTP error.
    handler: Box<
        dyn 'static
            + Send
            + Sync
            + Fn(Request) -> BoxFuture<'static, Result<Response, http::Error>>,
    >,
}
195
196impl FakeHttpClient {
197 pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
198 where
199 Fut: 'static + Send + Future<Output = Result<Response, http::Error>>,
200 F: 'static + Send + Sync + Fn(Request) -> Fut,
201 {
202 Arc::new(Self {
203 handler: Box::new(move |req| Box::pin(handler(req))),
204 })
205 }
206
207 pub fn with_404_response() -> Arc<dyn HttpClient> {
208 Self::new(|_| async move {
209 Ok(isahc::Response::builder()
210 .status(404)
211 .body(Default::default())
212 .unwrap())
213 })
214 }
215}
216
217impl fmt::Debug for FakeHttpClient {
218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219 f.debug_struct("FakeHttpClient").finish()
220 }
221}
222
223impl HttpClient for FakeHttpClient {
224 fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response, crate::http::Error>> {
225 let future = (self.handler)(req);
226 Box::pin(async move { future.await.map(Into::into) })
227 }
228}