1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
2use anyhow::{Context as _, Result, anyhow};
3use chrono::Duration;
4use futures::{StreamExt, stream::BoxStream};
5use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
6use parking_lot::Mutex;
7use rpc::{
8 ConnectionId, Peer, Receipt, TypedEnvelope,
9 proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
10};
11use std::sync::Arc;
12
13pub struct FakeServer {
14 peer: Arc<Peer>,
15 state: Arc<Mutex<FakeServerState>>,
16 user_id: u64,
17 executor: BackgroundExecutor,
18}
19
20#[derive(Default)]
21struct FakeServerState {
22 incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
23 connection_id: Option<ConnectionId>,
24 forbid_connections: bool,
25 auth_count: usize,
26 access_token: usize,
27}
28
29impl FakeServer {
30 pub async fn for_client(
31 client_user_id: u64,
32 client: &Arc<Client>,
33 cx: &TestAppContext,
34 ) -> Self {
35 let server = Self {
36 peer: Peer::new(0),
37 state: Default::default(),
38 user_id: client_user_id,
39 executor: cx.executor(),
40 };
41
42 client
43 .override_authenticate({
44 let state = Arc::downgrade(&server.state);
45 move |cx| {
46 let state = state.clone();
47 cx.spawn(async move |_| {
48 let state = state.upgrade().context("server dropped")?;
49 let mut state = state.lock();
50 state.auth_count += 1;
51 let access_token = state.access_token.to_string();
52 Ok(Credentials {
53 user_id: client_user_id,
54 access_token,
55 })
56 })
57 }
58 })
59 .override_establish_connection({
60 let peer = Arc::downgrade(&server.peer);
61 let state = Arc::downgrade(&server.state);
62 move |credentials, cx| {
63 let peer = peer.clone();
64 let state = state.clone();
65 let credentials = credentials.clone();
66 cx.spawn(async move |cx| {
67 let state = state.upgrade().context("server dropped")?;
68 let peer = peer.upgrade().context("server dropped")?;
69 if state.lock().forbid_connections {
70 Err(EstablishConnectionError::Other(anyhow!(
71 "server is forbidding connections"
72 )))?
73 }
74
75 if credentials
76 != (Credentials {
77 user_id: client_user_id,
78 access_token: state.lock().access_token.to_string(),
79 })
80 {
81 Err(EstablishConnectionError::Unauthorized)?
82 }
83
84 let (client_conn, server_conn, _) =
85 Connection::in_memory(cx.background_executor().clone());
86 let (connection_id, io, incoming) =
87 peer.add_test_connection(server_conn, cx.background_executor().clone());
88 cx.background_spawn(io).detach();
89 {
90 let mut state = state.lock();
91 state.connection_id = Some(connection_id);
92 state.incoming = Some(incoming);
93 }
94 peer.send(
95 connection_id,
96 proto::Hello {
97 peer_id: Some(connection_id.into()),
98 },
99 )
100 .unwrap();
101
102 Ok(client_conn)
103 })
104 }
105 });
106
107 client
108 .authenticate_and_connect(false, &cx.to_async())
109 .await
110 .into_response()
111 .unwrap();
112
113 server
114 }
115
116 pub fn disconnect(&self) {
117 if self.state.lock().connection_id.is_some() {
118 self.peer.disconnect(self.connection_id());
119 let mut state = self.state.lock();
120 state.connection_id.take();
121 state.incoming.take();
122 }
123 }
124
125 pub fn auth_count(&self) -> usize {
126 self.state.lock().auth_count
127 }
128
129 pub fn roll_access_token(&self) {
130 self.state.lock().access_token += 1;
131 }
132
133 pub fn forbid_connections(&self) {
134 self.state.lock().forbid_connections = true;
135 }
136
137 pub fn allow_connections(&self) {
138 self.state.lock().forbid_connections = false;
139 }
140
141 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
142 self.peer.send(self.connection_id(), message).unwrap();
143 }
144
145 #[allow(clippy::await_holding_lock)]
146 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
147 self.executor.start_waiting();
148
149 loop {
150 let message = self
151 .state
152 .lock()
153 .incoming
154 .as_mut()
155 .expect("not connected")
156 .next()
157 .await
158 .context("other half hung up")?;
159 self.executor.finish_waiting();
160 let type_name = message.payload_type_name();
161 let message = message.into_any();
162
163 if message.is::<TypedEnvelope<M>>() {
164 return Ok(*message.downcast().unwrap());
165 }
166
167 let accepted_tos_at = chrono::Utc::now()
168 .checked_sub_signed(Duration::hours(5))
169 .expect("failed to build accepted_tos_at")
170 .timestamp() as u64;
171
172 if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
173 self.respond(
174 message
175 .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
176 .unwrap()
177 .receipt(),
178 GetPrivateUserInfoResponse {
179 metrics_id: "the-metrics-id".into(),
180 staff: false,
181 flags: Default::default(),
182 accepted_tos_at: Some(accepted_tos_at),
183 },
184 );
185 continue;
186 }
187
188 panic!(
189 "fake server received unexpected message type: {:?}",
190 type_name
191 );
192 }
193 }
194
195 pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
196 self.peer.respond(receipt, response).unwrap()
197 }
198
199 fn connection_id(&self) -> ConnectionId {
200 self.state.lock().connection_id.expect("not connected")
201 }
202
203 pub async fn build_user_store(
204 &self,
205 client: Arc<Client>,
206 cx: &mut TestAppContext,
207 ) -> Entity<UserStore> {
208 let user_store = cx.new(|cx| UserStore::new(client, cx));
209 assert_eq!(
210 self.receive::<proto::GetUsers>()
211 .await
212 .unwrap()
213 .payload
214 .user_ids,
215 &[self.user_id]
216 );
217 user_store
218 }
219}
220
221impl Drop for FakeServer {
222 fn drop(&mut self) {
223 self.disconnect();
224 }
225}