1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
2use anyhow::{Context as _, Result, anyhow};
3use chrono::Duration;
4use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
5use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
6use futures::{StreamExt, stream::BoxStream};
7use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
8use http_client::{AsyncBody, Method, Request, http};
9use parking_lot::Mutex;
10use rpc::{
11 ConnectionId, Peer, Receipt, TypedEnvelope,
12 proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
13};
14use std::sync::Arc;
15
16pub struct FakeServer {
17 peer: Arc<Peer>,
18 state: Arc<Mutex<FakeServerState>>,
19 user_id: u64,
20 executor: BackgroundExecutor,
21}
22
23#[derive(Default)]
24struct FakeServerState {
25 incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
26 connection_id: Option<ConnectionId>,
27 forbid_connections: bool,
28 auth_count: usize,
29 access_token: usize,
30}
31
32impl FakeServer {
33 pub async fn for_client(
34 client_user_id: u64,
35 client: &Arc<Client>,
36 cx: &TestAppContext,
37 ) -> Self {
38 let server = Self {
39 peer: Peer::new(0),
40 state: Default::default(),
41 user_id: client_user_id,
42 executor: cx.executor(),
43 };
44
45 client.http_client().as_fake().replace_handler({
46 let state = server.state.clone();
47 move |old_handler, req| {
48 let state = state.clone();
49 let old_handler = old_handler.clone();
50 async move {
51 match (req.method(), req.uri().path()) {
52 (&Method::GET, "/client/users/me") => {
53 let credentials = parse_authorization_header(&req);
54 if credentials
55 != Some(Credentials {
56 user_id: client_user_id,
57 access_token: state.lock().access_token.to_string(),
58 })
59 {
60 return Ok(http_client::Response::builder()
61 .status(401)
62 .body("Unauthorized".into())
63 .unwrap());
64 }
65
66 Ok(http_client::Response::builder()
67 .status(200)
68 .body(
69 serde_json::to_string(&make_get_authenticated_user_response(
70 client_user_id as i32,
71 format!("user-{client_user_id}"),
72 ))
73 .unwrap()
74 .into(),
75 )
76 .unwrap())
77 }
78 _ => old_handler(req).await,
79 }
80 }
81 }
82 });
83 client
84 .override_authenticate({
85 let state = Arc::downgrade(&server.state);
86 move |cx| {
87 let state = state.clone();
88 cx.spawn(async move |_| {
89 let state = state.upgrade().context("server dropped")?;
90 let mut state = state.lock();
91 state.auth_count += 1;
92 let access_token = state.access_token.to_string();
93 Ok(Credentials {
94 user_id: client_user_id,
95 access_token,
96 })
97 })
98 }
99 })
100 .override_establish_connection({
101 let peer = Arc::downgrade(&server.peer);
102 let state = Arc::downgrade(&server.state);
103 move |credentials, cx| {
104 let peer = peer.clone();
105 let state = state.clone();
106 let credentials = credentials.clone();
107 cx.spawn(async move |cx| {
108 let state = state.upgrade().context("server dropped")?;
109 let peer = peer.upgrade().context("server dropped")?;
110 if state.lock().forbid_connections {
111 Err(EstablishConnectionError::Other(anyhow!(
112 "server is forbidding connections"
113 )))?
114 }
115
116 if credentials
117 != (Credentials {
118 user_id: client_user_id,
119 access_token: state.lock().access_token.to_string(),
120 })
121 {
122 Err(EstablishConnectionError::Unauthorized)?
123 }
124
125 let (client_conn, server_conn, _) =
126 Connection::in_memory(cx.background_executor().clone());
127 let (connection_id, io, incoming) =
128 peer.add_test_connection(server_conn, cx.background_executor().clone());
129 cx.background_spawn(io).detach();
130 {
131 let mut state = state.lock();
132 state.connection_id = Some(connection_id);
133 state.incoming = Some(incoming);
134 }
135 peer.send(
136 connection_id,
137 proto::Hello {
138 peer_id: Some(connection_id.into()),
139 },
140 )
141 .unwrap();
142
143 Ok(client_conn)
144 })
145 }
146 });
147
148 client
149 .connect(false, &cx.to_async())
150 .await
151 .into_response()
152 .unwrap();
153
154 server
155 }
156
157 pub fn disconnect(&self) {
158 if self.state.lock().connection_id.is_some() {
159 self.peer.disconnect(self.connection_id());
160 let mut state = self.state.lock();
161 state.connection_id.take();
162 state.incoming.take();
163 }
164 }
165
166 pub fn auth_count(&self) -> usize {
167 self.state.lock().auth_count
168 }
169
170 pub fn roll_access_token(&self) {
171 self.state.lock().access_token += 1;
172 }
173
174 pub fn forbid_connections(&self) {
175 self.state.lock().forbid_connections = true;
176 }
177
178 pub fn allow_connections(&self) {
179 self.state.lock().forbid_connections = false;
180 }
181
182 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
183 self.peer.send(self.connection_id(), message).unwrap();
184 }
185
186 #[allow(clippy::await_holding_lock)]
187 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
188 self.executor.start_waiting();
189
190 loop {
191 let message = self
192 .state
193 .lock()
194 .incoming
195 .as_mut()
196 .expect("not connected")
197 .next()
198 .await
199 .context("other half hung up")?;
200 self.executor.finish_waiting();
201 let type_name = message.payload_type_name();
202 let message = message.into_any();
203
204 if message.is::<TypedEnvelope<M>>() {
205 return Ok(*message.downcast().unwrap());
206 }
207
208 let accepted_tos_at = chrono::Utc::now()
209 .checked_sub_signed(Duration::hours(5))
210 .expect("failed to build accepted_tos_at")
211 .timestamp() as u64;
212
213 if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
214 self.respond(
215 message
216 .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
217 .unwrap()
218 .receipt(),
219 GetPrivateUserInfoResponse {
220 metrics_id: "the-metrics-id".into(),
221 staff: false,
222 flags: Default::default(),
223 accepted_tos_at: Some(accepted_tos_at),
224 },
225 );
226 continue;
227 }
228
229 panic!(
230 "fake server received unexpected message type: {:?}",
231 type_name
232 );
233 }
234 }
235
236 pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
237 self.peer.respond(receipt, response).unwrap()
238 }
239
240 fn connection_id(&self) -> ConnectionId {
241 self.state.lock().connection_id.expect("not connected")
242 }
243
244 pub async fn build_user_store(
245 &self,
246 client: Arc<Client>,
247 cx: &mut TestAppContext,
248 ) -> Entity<UserStore> {
249 let user_store = cx.new(|cx| UserStore::new(client, cx));
250 assert_eq!(
251 self.receive::<proto::GetUsers>()
252 .await
253 .unwrap()
254 .payload
255 .user_ids,
256 &[self.user_id]
257 );
258 user_store
259 }
260}
261
262impl Drop for FakeServer {
263 fn drop(&mut self) {
264 self.disconnect();
265 }
266}
267
268pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
269 let mut auth_header = req
270 .headers()
271 .get(http::header::AUTHORIZATION)?
272 .to_str()
273 .ok()?
274 .split_whitespace();
275 let user_id = auth_header.next()?.parse().ok()?;
276 let access_token = auth_header.next()?;
277 Some(Credentials {
278 user_id,
279 access_token: access_token.to_string(),
280 })
281}
282
283pub fn make_get_authenticated_user_response(
284 user_id: i32,
285 github_login: String,
286) -> GetAuthenticatedUserResponse {
287 GetAuthenticatedUserResponse {
288 user: AuthenticatedUser {
289 id: user_id,
290 metrics_id: format!("metrics-id-{user_id}"),
291 avatar_url: "".to_string(),
292 github_login,
293 name: None,
294 is_staff: false,
295 accepted_tos_at: None,
296 },
297 feature_flags: vec![],
298 plan: PlanInfo {
299 plan: Plan::ZedPro,
300 subscription_period: None,
301 usage: CurrentUsage {
302 model_requests: UsageData {
303 used: 0,
304 limit: UsageLimit::Limited(500),
305 },
306 edit_predictions: UsageData {
307 used: 250,
308 limit: UsageLimit::Unlimited,
309 },
310 },
311 trial_started_at: None,
312 is_usage_based_billing_enabled: false,
313 is_account_too_young: false,
314 has_overdue_invoices: false,
315 },
316 }
317}