1use crate::{
2 http::{HttpClient, Request, Response, ServerResponse},
3 Client, Connection, Credentials, EstablishConnectionError, UserStore,
4};
5use anyhow::{anyhow, Result};
6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
7use gpui::{executor, ModelHandle, TestAppContext};
8use parking_lot::Mutex;
9use postage::barrier;
10use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
11use std::{fmt, rc::Rc, sync::Arc};
12
/// A fake RPC server for tests that serves a [`Client`] over an in-memory connection.
///
/// Construct it with [`FakeServer::for_client`], which installs authentication and
/// connection hooks on the client and connects it to this server.
pub struct FakeServer {
    // Server-side RPC peer; each connection the client establishes is registered with it.
    peer: Arc<Peer>,
    // Mutable state, shared with the hooks installed on the client in `for_client`.
    state: Arc<Mutex<FakeServerState>>,
    // The user id the client authenticates as; checked again in `build_user_store`.
    user_id: u64,
    // The test's foreground executor; `receive` brackets its await with
    // `start_waiting` / `finish_waiting` on it.
    executor: Rc<executor::Foreground>,
}
19
/// Mutable state of a [`FakeServer`], shared with the hooks it installs on the client.
#[derive(Default)]
struct FakeServerState {
    // Messages sent by the currently connected client; `None` before the first
    // connection and after `disconnect`.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    // Id of the current connection in the server's peer; `None` when not connected.
    connection_id: Option<ConnectionId>,
    // When true, the establish-connection hook rejects every attempt with an error.
    forbid_connections: bool,
    // Number of times the client's authenticate hook has been invoked.
    auth_count: usize,
    // Third value returned by `Connection::in_memory` (named `kill` there) for the
    // current connection. It is stored but not otherwise used in this file;
    // presumably it can be used to sever the in-memory connection (TODO confirm).
    connection_killer: Option<barrier::Sender>,
    // The currently valid access token. Credentials are issued with its string form,
    // and connections presenting a different token are rejected as unauthorized.
    access_token: usize,
}
29
30impl FakeServer {
31 pub async fn for_client(
32 client_user_id: u64,
33 client: &mut Arc<Client>,
34 cx: &TestAppContext,
35 ) -> Self {
36 let server = Self {
37 peer: Peer::new(),
38 state: Default::default(),
39 user_id: client_user_id,
40 executor: cx.foreground(),
41 };
42
43 Arc::get_mut(client)
44 .unwrap()
45 .override_authenticate({
46 let state = server.state.clone();
47 move |cx| {
48 let mut state = state.lock();
49 state.auth_count += 1;
50 let access_token = state.access_token.to_string();
51 cx.spawn(move |_| async move {
52 Ok(Credentials {
53 user_id: client_user_id,
54 access_token,
55 })
56 })
57 }
58 })
59 .override_establish_connection({
60 let peer = server.peer.clone();
61 let state = server.state.clone();
62 move |credentials, cx| {
63 let peer = peer.clone();
64 let state = state.clone();
65 let credentials = credentials.clone();
66 cx.spawn(move |cx| async move {
67 assert_eq!(credentials.user_id, client_user_id);
68
69 if state.lock().forbid_connections {
70 Err(EstablishConnectionError::Other(anyhow!(
71 "server is forbidding connections"
72 )))?
73 }
74
75 if credentials.access_token != state.lock().access_token.to_string() {
76 Err(EstablishConnectionError::Unauthorized)?
77 }
78
79 let (client_conn, server_conn, kill) =
80 Connection::in_memory(cx.background());
81 let (connection_id, io, incoming) =
82 peer.add_test_connection(server_conn, cx.background()).await;
83 cx.background().spawn(io).detach();
84 let mut state = state.lock();
85 state.connection_id = Some(connection_id);
86 state.incoming = Some(incoming);
87 state.connection_killer = Some(kill);
88 Ok(client_conn)
89 })
90 }
91 });
92
93 client
94 .authenticate_and_connect(false, &cx.to_async())
95 .await
96 .unwrap();
97 server
98 }
99
100 pub fn disconnect(&self) {
101 self.peer.disconnect(self.connection_id());
102 let mut state = self.state.lock();
103 state.connection_id.take();
104 state.incoming.take();
105 }
106
107 pub fn auth_count(&self) -> usize {
108 self.state.lock().auth_count
109 }
110
111 pub fn roll_access_token(&self) {
112 self.state.lock().access_token += 1;
113 }
114
115 pub fn forbid_connections(&self) {
116 self.state.lock().forbid_connections = true;
117 }
118
119 pub fn allow_connections(&self) {
120 self.state.lock().forbid_connections = false;
121 }
122
123 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
124 self.peer.send(self.connection_id(), message).unwrap();
125 }
126
127 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
128 self.executor.start_waiting();
129 let message = self
130 .state
131 .lock()
132 .incoming
133 .as_mut()
134 .expect("not connected")
135 .next()
136 .await
137 .ok_or_else(|| anyhow!("other half hung up"))?;
138 self.executor.finish_waiting();
139 let type_name = message.payload_type_name();
140 Ok(*message
141 .into_any()
142 .downcast::<TypedEnvelope<M>>()
143 .unwrap_or_else(|_| {
144 panic!(
145 "fake server received unexpected message type: {:?}",
146 type_name
147 );
148 }))
149 }
150
151 pub async fn respond<T: proto::RequestMessage>(
152 &self,
153 receipt: Receipt<T>,
154 response: T::Response,
155 ) {
156 self.peer.respond(receipt, response).unwrap()
157 }
158
159 fn connection_id(&self) -> ConnectionId {
160 self.state.lock().connection_id.expect("not connected")
161 }
162
163 pub async fn build_user_store(
164 &self,
165 client: Arc<Client>,
166 cx: &mut TestAppContext,
167 ) -> ModelHandle<UserStore> {
168 let http_client = FakeHttpClient::with_404_response();
169 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
170 assert_eq!(
171 self.receive::<proto::GetUsers>()
172 .await
173 .unwrap()
174 .payload
175 .user_ids,
176 &[self.user_id]
177 );
178 user_store
179 }
180}
181
182pub struct FakeHttpClient {
183 handler:
184 Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
185}
186
187impl FakeHttpClient {
188 pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
189 where
190 Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
191 F: 'static + Send + Sync + Fn(Request) -> Fut,
192 {
193 Arc::new(Self {
194 handler: Box::new(move |req| Box::pin(handler(req))),
195 })
196 }
197
198 pub fn with_404_response() -> Arc<dyn HttpClient> {
199 Self::new(|_| async move { Ok(ServerResponse::new(404)) })
200 }
201}
202
203impl fmt::Debug for FakeHttpClient {
204 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205 f.debug_struct("FakeHttpClient").finish()
206 }
207}
208
209impl HttpClient for FakeHttpClient {
210 fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
211 let future = (self.handler)(req);
212 Box::pin(async move { future.await.map(Into::into) })
213 }
214}