1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
2use anyhow::{anyhow, Result};
3use futures::{stream::BoxStream, StreamExt};
4use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
5use parking_lot::Mutex;
6use rpc::{
7 proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
8 ConnectionId, Peer, Receipt, TypedEnvelope,
9};
10use std::sync::Arc;
11
/// An in-process stand-in for the collaboration server, used by client tests.
///
/// Construct it with `FakeServer::for_client`, which reroutes the client's
/// authentication and connection hooks to this server. Dropping the server
/// disconnects the client.
pub struct FakeServer {
    /// Server side of the RPC peer; outgoing messages (`send`, `respond`) go through it.
    peer: Arc<Peer>,
    /// Mutable connection/auth state. The client's overridden hooks hold only
    /// weak references to it.
    state: Arc<Mutex<FakeServerState>>,
    /// User id the client authenticates as; `build_user_store` expects a
    /// `GetUsers` request for exactly this id.
    user_id: u64,
    /// Test executor used to bracket the waits performed in `receive`.
    executor: BackgroundExecutor,
}
18
19#[derive(Default)]
20struct FakeServerState {
21 incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
22 connection_id: Option<ConnectionId>,
23 forbid_connections: bool,
24 auth_count: usize,
25 access_token: usize,
26}
27
28impl FakeServer {
29 pub async fn for_client(
30 client_user_id: u64,
31 client: &Arc<Client>,
32 cx: &TestAppContext,
33 ) -> Self {
34 let server = Self {
35 peer: Peer::new(0),
36 state: Default::default(),
37 user_id: client_user_id,
38 executor: cx.executor(),
39 };
40
41 client
42 .override_authenticate({
43 let state = Arc::downgrade(&server.state);
44 move |cx| {
45 let state = state.clone();
46 cx.spawn(move |_| async move {
47 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
48 let mut state = state.lock();
49 state.auth_count += 1;
50 let access_token = state.access_token.to_string();
51 Ok(Credentials::User {
52 user_id: client_user_id,
53 access_token,
54 })
55 })
56 }
57 })
58 .override_establish_connection({
59 let peer = Arc::downgrade(&server.peer);
60 let state = Arc::downgrade(&server.state);
61 move |credentials, cx| {
62 let peer = peer.clone();
63 let state = state.clone();
64 let credentials = credentials.clone();
65 cx.spawn(move |cx| async move {
66 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
67 let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
68 if state.lock().forbid_connections {
69 Err(EstablishConnectionError::Other(anyhow!(
70 "server is forbidding connections"
71 )))?
72 }
73
74 if credentials
75 != (Credentials::User {
76 user_id: client_user_id,
77 access_token: state.lock().access_token.to_string(),
78 })
79 {
80 Err(EstablishConnectionError::Unauthorized)?
81 }
82
83 let (client_conn, server_conn, _) =
84 Connection::in_memory(cx.background_executor().clone());
85 let (connection_id, io, incoming) =
86 peer.add_test_connection(server_conn, cx.background_executor().clone());
87 cx.background_executor().spawn(io).detach();
88 {
89 let mut state = state.lock();
90 state.connection_id = Some(connection_id);
91 state.incoming = Some(incoming);
92 }
93 peer.send(
94 connection_id,
95 proto::Hello {
96 peer_id: Some(connection_id.into()),
97 },
98 )
99 .unwrap();
100
101 Ok(client_conn)
102 })
103 }
104 });
105
106 client
107 .authenticate_and_connect(false, &cx.to_async())
108 .await
109 .unwrap();
110
111 server
112 }
113
114 pub fn disconnect(&self) {
115 if self.state.lock().connection_id.is_some() {
116 self.peer.disconnect(self.connection_id());
117 let mut state = self.state.lock();
118 state.connection_id.take();
119 state.incoming.take();
120 }
121 }
122
123 pub fn auth_count(&self) -> usize {
124 self.state.lock().auth_count
125 }
126
127 pub fn roll_access_token(&self) {
128 self.state.lock().access_token += 1;
129 }
130
131 pub fn forbid_connections(&self) {
132 self.state.lock().forbid_connections = true;
133 }
134
135 pub fn allow_connections(&self) {
136 self.state.lock().forbid_connections = false;
137 }
138
139 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
140 self.peer.send(self.connection_id(), message).unwrap();
141 }
142
143 #[allow(clippy::await_holding_lock)]
144 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
145 self.executor.start_waiting();
146
147 loop {
148 let message = self
149 .state
150 .lock()
151 .incoming
152 .as_mut()
153 .expect("not connected")
154 .next()
155 .await
156 .ok_or_else(|| anyhow!("other half hung up"))?;
157 self.executor.finish_waiting();
158 let type_name = message.payload_type_name();
159 let message = message.into_any();
160
161 if message.is::<TypedEnvelope<M>>() {
162 return Ok(*message.downcast().unwrap());
163 }
164
165 if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
166 self.respond(
167 message
168 .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
169 .unwrap()
170 .receipt(),
171 GetPrivateUserInfoResponse {
172 metrics_id: "the-metrics-id".into(),
173 staff: false,
174 flags: Default::default(),
175 },
176 );
177 continue;
178 }
179
180 panic!(
181 "fake server received unexpected message type: {:?}",
182 type_name
183 );
184 }
185 }
186
187 pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
188 self.peer.respond(receipt, response).unwrap()
189 }
190
191 fn connection_id(&self) -> ConnectionId {
192 self.state.lock().connection_id.expect("not connected")
193 }
194
195 pub async fn build_user_store(
196 &self,
197 client: Arc<Client>,
198 cx: &mut TestAppContext,
199 ) -> Model<UserStore> {
200 let user_store = cx.new_model(|cx| UserStore::new(client, cx));
201 assert_eq!(
202 self.receive::<proto::GetUsers>()
203 .await
204 .unwrap()
205 .payload
206 .user_ids,
207 &[self.user_id]
208 );
209 user_store
210 }
211}
212
213impl Drop for FakeServer {
214 fn drop(&mut self) {
215 self.disconnect();
216 }
217}