1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
2use anyhow::{anyhow, Result};
3use futures::{stream::BoxStream, StreamExt};
4use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
5use parking_lot::Mutex;
6use rpc::{
7 proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
8 ConnectionId, Peer, Receipt, TypedEnvelope,
9};
10use std::sync::Arc;
11use util::http::FakeHttpClient;
12
/// An in-memory stand-in for the RPC server, used by client tests.
///
/// Built with [`FakeServer::for_client`], which redirects a [`Client`]'s
/// authentication and connection hooks to this server.
pub struct FakeServer {
    /// Server-side RPC peer for the in-memory connection.
    peer: Arc<Peer>,
    /// Connection/auth state; the client's overridden hooks hold weak references to it.
    state: Arc<Mutex<FakeServerState>>,
    /// Id of the user the client connects as; `build_user_store` expects a
    /// `GetUsers` request for exactly this id.
    user_id: u64,
    /// Test executor; `receive` uses it to mark when the test is waiting on a message.
    executor: BackgroundExecutor,
}
19
/// Mutable state behind a [`FakeServer`], shared with the client's overridden hooks.
#[derive(Default)]
struct FakeServerState {
    /// Messages received from the client; set when a connection is established
    /// and cleared by `disconnect`.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current server-side connection, if connected.
    connection_id: Option<ConnectionId>,
    /// When `true`, connection attempts are rejected.
    forbid_connections: bool,
    /// Number of times the client's authenticate hook has run.
    auth_count: usize,
    /// The currently valid access token; compared against credentials in its
    /// decimal string form.
    access_token: usize,
}
28
29impl FakeServer {
30 pub async fn for_client(
31 client_user_id: u64,
32 client: &Arc<Client>,
33 cx: &TestAppContext,
34 ) -> Self {
35 let server = Self {
36 peer: Peer::new(0),
37 state: Default::default(),
38 user_id: client_user_id,
39 executor: cx.executor(),
40 };
41
42 client
43 .override_authenticate({
44 let state = Arc::downgrade(&server.state);
45 move |cx| {
46 let state = state.clone();
47 cx.spawn(move |_| async move {
48 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
49 let mut state = state.lock();
50 state.auth_count += 1;
51 let access_token = state.access_token.to_string();
52 Ok(Credentials {
53 user_id: client_user_id,
54 access_token,
55 })
56 })
57 }
58 })
59 .override_establish_connection({
60 let peer = Arc::downgrade(&server.peer);
61 let state = Arc::downgrade(&server.state);
62 move |credentials, cx| {
63 let peer = peer.clone();
64 let state = state.clone();
65 let credentials = credentials.clone();
66 cx.spawn(move |cx| async move {
67 let state = state.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
68 let peer = peer.upgrade().ok_or_else(|| anyhow!("server dropped"))?;
69 if state.lock().forbid_connections {
70 Err(EstablishConnectionError::Other(anyhow!(
71 "server is forbidding connections"
72 )))?
73 }
74
75 assert_eq!(credentials.user_id, client_user_id);
76
77 if credentials.access_token != state.lock().access_token.to_string() {
78 Err(EstablishConnectionError::Unauthorized)?
79 }
80
81 let (client_conn, server_conn, _) =
82 Connection::in_memory(cx.background_executor().clone());
83 let (connection_id, io, incoming) =
84 peer.add_test_connection(server_conn, cx.background_executor().clone());
85 cx.background_executor().spawn(io).detach();
86 {
87 let mut state = state.lock();
88 state.connection_id = Some(connection_id);
89 state.incoming = Some(incoming);
90 }
91 peer.send(
92 connection_id,
93 proto::Hello {
94 peer_id: Some(connection_id.into()),
95 },
96 )
97 .unwrap();
98
99 Ok(client_conn)
100 })
101 }
102 });
103
104 client
105 .authenticate_and_connect(false, &cx.to_async())
106 .await
107 .unwrap();
108
109 server
110 }
111
112 pub fn disconnect(&self) {
113 if self.state.lock().connection_id.is_some() {
114 self.peer.disconnect(self.connection_id());
115 let mut state = self.state.lock();
116 state.connection_id.take();
117 state.incoming.take();
118 }
119 }
120
121 pub fn auth_count(&self) -> usize {
122 self.state.lock().auth_count
123 }
124
125 pub fn roll_access_token(&self) {
126 self.state.lock().access_token += 1;
127 }
128
129 pub fn forbid_connections(&self) {
130 self.state.lock().forbid_connections = true;
131 }
132
133 pub fn allow_connections(&self) {
134 self.state.lock().forbid_connections = false;
135 }
136
137 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
138 self.peer.send(self.connection_id(), message).unwrap();
139 }
140
141 #[allow(clippy::await_holding_lock)]
142 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
143 self.executor.start_waiting();
144
145 loop {
146 let message = self
147 .state
148 .lock()
149 .incoming
150 .as_mut()
151 .expect("not connected")
152 .next()
153 .await
154 .ok_or_else(|| anyhow!("other half hung up"))?;
155 self.executor.finish_waiting();
156 let type_name = message.payload_type_name();
157 let message = message.into_any();
158
159 if message.is::<TypedEnvelope<M>>() {
160 return Ok(*message.downcast().unwrap());
161 }
162
163 if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
164 self.respond(
165 message
166 .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
167 .unwrap()
168 .receipt(),
169 GetPrivateUserInfoResponse {
170 metrics_id: "the-metrics-id".into(),
171 staff: false,
172 flags: Default::default(),
173 },
174 );
175 continue;
176 }
177
178 panic!(
179 "fake server received unexpected message type: {:?}",
180 type_name
181 );
182 }
183 }
184
185 pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
186 self.peer.respond(receipt, response).unwrap()
187 }
188
189 fn connection_id(&self) -> ConnectionId {
190 self.state.lock().connection_id.expect("not connected")
191 }
192
193 pub async fn build_user_store(
194 &self,
195 client: Arc<Client>,
196 cx: &mut TestAppContext,
197 ) -> Model<UserStore> {
198 let http_client = FakeHttpClient::with_404_response();
199 let user_store = cx.build_model(|cx| UserStore::new(client, http_client, cx));
200 assert_eq!(
201 self.receive::<proto::GetUsers>()
202 .await
203 .unwrap()
204 .payload
205 .user_ids,
206 &[self.user_id]
207 );
208 user_store
209 }
210}
211
212impl Drop for FakeServer {
213 fn drop(&mut self) {
214 self.disconnect();
215 }
216}