1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
2use anyhow::{Context as _, Result, anyhow};
3use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
4use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
5use futures::{StreamExt, stream::BoxStream};
6use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
7use http_client::{AsyncBody, Method, Request, http};
8use parking_lot::Mutex;
9use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
10use std::sync::Arc;
11
/// A fake server that a test [`Client`] connects to over an in-memory connection.
///
/// It answers the client's HTTP `GET /client/users/me` request and the client's
/// authenticate/establish-connection hooks. It also lets tests exchange RPC messages with the
/// client.
pub struct FakeServer {
    /// Server-side RPC peer of the in-memory connection.
    peer: Arc<Peer>,
    /// Mutable state, shared with the hooks installed on the client by `for_client`.
    state: Arc<Mutex<FakeServerState>>,
    /// The user id the client authenticates as.
    user_id: u64,
    /// Executor of the test context the server was created from.
    executor: BackgroundExecutor,
}
18
/// Mutable state of a [`FakeServer`], shared with the client-side hooks it installs.
#[derive(Default)]
struct FakeServerState {
    /// Messages received from the client on the current connection; `None` when disconnected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current client connection; `None` when disconnected.
    connection_id: Option<ConnectionId>,
    /// When set, connection attempts fail with `EstablishConnectionError::Other`.
    forbid_connections: bool,
    /// Number of times the client's authenticate hook has run.
    auth_count: usize,
    /// Current access token.
    ///
    /// Client credentials are compared against its string form. `roll_access_token`
    /// increments it.
    access_token: usize,
}
27
28impl FakeServer {
29 pub async fn for_client(
30 client_user_id: u64,
31 client: &Arc<Client>,
32 cx: &TestAppContext,
33 ) -> Self {
34 let server = Self {
35 peer: Peer::new(0),
36 state: Default::default(),
37 user_id: client_user_id,
38 executor: cx.executor(),
39 };
40
41 client.http_client().as_fake().replace_handler({
42 let state = server.state.clone();
43 move |old_handler, req| {
44 let state = state.clone();
45 let old_handler = old_handler.clone();
46 async move {
47 match (req.method(), req.uri().path()) {
48 (&Method::GET, "/client/users/me") => {
49 let credentials = parse_authorization_header(&req);
50 if credentials
51 != Some(Credentials {
52 user_id: client_user_id,
53 access_token: state.lock().access_token.to_string(),
54 })
55 {
56 return Ok(http_client::Response::builder()
57 .status(401)
58 .body("Unauthorized".into())
59 .unwrap());
60 }
61
62 Ok(http_client::Response::builder()
63 .status(200)
64 .body(
65 serde_json::to_string(&make_get_authenticated_user_response(
66 client_user_id as i32,
67 format!("user-{client_user_id}"),
68 ))
69 .unwrap()
70 .into(),
71 )
72 .unwrap())
73 }
74 _ => old_handler(req).await,
75 }
76 }
77 }
78 });
79 client
80 .override_authenticate({
81 let state = Arc::downgrade(&server.state);
82 move |cx| {
83 let state = state.clone();
84 cx.spawn(async move |_| {
85 let state = state.upgrade().context("server dropped")?;
86 let mut state = state.lock();
87 state.auth_count += 1;
88 let access_token = state.access_token.to_string();
89 Ok(Credentials {
90 user_id: client_user_id,
91 access_token,
92 })
93 })
94 }
95 })
96 .override_establish_connection({
97 let peer = Arc::downgrade(&server.peer);
98 let state = Arc::downgrade(&server.state);
99 move |credentials, cx| {
100 let peer = peer.clone();
101 let state = state.clone();
102 let credentials = credentials.clone();
103 cx.spawn(async move |cx| {
104 let state = state.upgrade().context("server dropped")?;
105 let peer = peer.upgrade().context("server dropped")?;
106 if state.lock().forbid_connections {
107 Err(EstablishConnectionError::Other(anyhow!(
108 "server is forbidding connections"
109 )))?
110 }
111
112 if credentials
113 != (Credentials {
114 user_id: client_user_id,
115 access_token: state.lock().access_token.to_string(),
116 })
117 {
118 Err(EstablishConnectionError::Unauthorized)?
119 }
120
121 let (client_conn, server_conn, _) =
122 Connection::in_memory(cx.background_executor().clone());
123 let (connection_id, io, incoming) =
124 peer.add_test_connection(server_conn, cx.background_executor().clone());
125 cx.background_spawn(io).detach();
126 {
127 let mut state = state.lock();
128 state.connection_id = Some(connection_id);
129 state.incoming = Some(incoming);
130 }
131 peer.send(
132 connection_id,
133 proto::Hello {
134 peer_id: Some(connection_id.into()),
135 },
136 )
137 .unwrap();
138
139 Ok(client_conn)
140 })
141 }
142 });
143
144 client
145 .connect(false, &cx.to_async())
146 .await
147 .into_response()
148 .unwrap();
149
150 server
151 }
152
153 pub fn disconnect(&self) {
154 if self.state.lock().connection_id.is_some() {
155 self.peer.disconnect(self.connection_id());
156 let mut state = self.state.lock();
157 state.connection_id.take();
158 state.incoming.take();
159 }
160 }
161
162 pub fn auth_count(&self) -> usize {
163 self.state.lock().auth_count
164 }
165
166 pub fn roll_access_token(&self) {
167 self.state.lock().access_token += 1;
168 }
169
170 pub fn forbid_connections(&self) {
171 self.state.lock().forbid_connections = true;
172 }
173
174 pub fn allow_connections(&self) {
175 self.state.lock().forbid_connections = false;
176 }
177
178 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
179 self.peer.send(self.connection_id(), message).unwrap();
180 }
181
182 #[allow(clippy::await_holding_lock)]
183 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
184 self.executor.start_waiting();
185
186 let message = self
187 .state
188 .lock()
189 .incoming
190 .as_mut()
191 .expect("not connected")
192 .next()
193 .await
194 .context("other half hung up")?;
195 self.executor.finish_waiting();
196 let type_name = message.payload_type_name();
197 let message = message.into_any();
198
199 if message.is::<TypedEnvelope<M>>() {
200 return Ok(*message.downcast().unwrap());
201 }
202
203 panic!(
204 "fake server received unexpected message type: {:?}",
205 type_name
206 );
207 }
208
209 pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
210 self.peer.respond(receipt, response).unwrap()
211 }
212
213 fn connection_id(&self) -> ConnectionId {
214 self.state.lock().connection_id.expect("not connected")
215 }
216
217 pub async fn build_user_store(
218 &self,
219 client: Arc<Client>,
220 cx: &mut TestAppContext,
221 ) -> Entity<UserStore> {
222 let user_store = cx.new(|cx| UserStore::new(client, cx));
223 assert_eq!(
224 self.receive::<proto::GetUsers>()
225 .await
226 .unwrap()
227 .payload
228 .user_ids,
229 &[self.user_id]
230 );
231 user_store
232 }
233}
234
235impl Drop for FakeServer {
236 fn drop(&mut self) {
237 self.disconnect();
238 }
239}
240
241pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
242 let mut auth_header = req
243 .headers()
244 .get(http::header::AUTHORIZATION)?
245 .to_str()
246 .ok()?
247 .split_whitespace();
248 let user_id = auth_header.next()?.parse().ok()?;
249 let access_token = auth_header.next()?;
250 Some(Credentials {
251 user_id,
252 access_token: access_token.to_string(),
253 })
254}
255
256pub fn make_get_authenticated_user_response(
257 user_id: i32,
258 github_login: String,
259) -> GetAuthenticatedUserResponse {
260 GetAuthenticatedUserResponse {
261 user: AuthenticatedUser {
262 id: user_id,
263 metrics_id: format!("metrics-id-{user_id}"),
264 avatar_url: "".to_string(),
265 github_login,
266 name: None,
267 is_staff: false,
268 accepted_tos_at: None,
269 },
270 feature_flags: vec![],
271 plan: PlanInfo {
272 plan: Plan::ZedPro,
273 subscription_period: None,
274 usage: CurrentUsage {
275 model_requests: UsageData {
276 used: 0,
277 limit: UsageLimit::Limited(500),
278 },
279 edit_predictions: UsageData {
280 used: 250,
281 limit: UsageLimit::Unlimited,
282 },
283 },
284 trial_started_at: None,
285 is_usage_based_billing_enabled: false,
286 is_account_too_young: false,
287 has_overdue_invoices: false,
288 },
289 }
290}