1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
2use anyhow::{Context as _, Result, anyhow};
3use cloud_api_client::{
4 AuthenticatedUser, GetAuthenticatedUserResponse, KnownOrUnknown, Plan, PlanInfo,
5};
6use cloud_llm_client::{CurrentUsage, UsageData, UsageLimit};
7use futures::{StreamExt, stream::BoxStream};
8use gpui::{AppContext as _, Entity, TestAppContext};
9use http_client::{AsyncBody, Method, Request, http};
10use parking_lot::Mutex;
11use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
12use std::sync::Arc;
13
/// An in-process stand-in for the server, used by client tests.
///
/// Created with [`FakeServer::for_client`], which installs HTTP, authentication and
/// connection overrides on a [`Client`] so that it talks to this fake instead of a
/// real server.
pub struct FakeServer {
    /// Server-side RPC peer that owns the in-memory connection to the client.
    peer: Arc<Peer>,
    /// Mutable server state, shared with the overrides installed on the client.
    state: Arc<Mutex<FakeServerState>>,
    /// User id the connected client authenticates as.
    user_id: u64,
}
19
/// Mutable state of a [`FakeServer`], kept behind a mutex so the client-side
/// overrides can read and update it.
#[derive(Default)]
struct FakeServerState {
    /// Stream of envelopes arriving from the client; `None` while disconnected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current client connection; `None` while disconnected.
    connection_id: Option<ConnectionId>,
    /// When set, connection attempts fail with "server is forbidding connections".
    forbid_connections: bool,
    /// Number of times the client has invoked the authenticate override.
    auth_count: usize,
    /// Current access token, kept as a counter. Credentials are compared against its
    /// decimal string form, and `roll_access_token` bumps it to invalidate old ones.
    access_token: usize,
}
28
29impl FakeServer {
30 pub async fn for_client(
31 client_user_id: u64,
32 client: &Arc<Client>,
33 cx: &TestAppContext,
34 ) -> Self {
35 let server = Self {
36 peer: Peer::new(0),
37 state: Default::default(),
38 user_id: client_user_id,
39 };
40
41 client.http_client().as_fake().replace_handler({
42 let state = server.state.clone();
43 move |old_handler, req| {
44 let state = state.clone();
45 let old_handler = old_handler.clone();
46 async move {
47 match (req.method(), req.uri().path()) {
48 (&Method::GET, "/client/users/me") => {
49 let credentials = parse_authorization_header(&req);
50 if credentials
51 != Some(Credentials {
52 user_id: client_user_id,
53 access_token: state.lock().access_token.to_string(),
54 })
55 {
56 return Ok(http_client::Response::builder()
57 .status(401)
58 .body("Unauthorized".into())
59 .unwrap());
60 }
61
62 Ok(http_client::Response::builder()
63 .status(200)
64 .body(
65 serde_json::to_string(&make_get_authenticated_user_response(
66 client_user_id as i32,
67 format!("user-{client_user_id}"),
68 ))
69 .unwrap()
70 .into(),
71 )
72 .unwrap())
73 }
74 _ => old_handler(req).await,
75 }
76 }
77 }
78 });
79 client
80 .override_authenticate({
81 let state = Arc::downgrade(&server.state);
82 move |cx| {
83 let state = state.clone();
84 cx.spawn(async move |_| {
85 let state = state.upgrade().context("server dropped")?;
86 let mut state = state.lock();
87 state.auth_count += 1;
88 let access_token = state.access_token.to_string();
89 Ok(Credentials {
90 user_id: client_user_id,
91 access_token,
92 })
93 })
94 }
95 })
96 .override_establish_connection({
97 let peer = Arc::downgrade(&server.peer);
98 let state = Arc::downgrade(&server.state);
99 move |credentials, cx| {
100 let peer = peer.clone();
101 let state = state.clone();
102 let credentials = credentials.clone();
103 cx.spawn(async move |cx| {
104 let state = state.upgrade().context("server dropped")?;
105 let peer = peer.upgrade().context("server dropped")?;
106 if state.lock().forbid_connections {
107 Err(EstablishConnectionError::Other(anyhow!(
108 "server is forbidding connections"
109 )))?
110 }
111
112 if credentials
113 != (Credentials {
114 user_id: client_user_id,
115 access_token: state.lock().access_token.to_string(),
116 })
117 {
118 Err(EstablishConnectionError::Unauthorized)?
119 }
120
121 let (client_conn, server_conn, _) =
122 Connection::in_memory(cx.background_executor().clone());
123 let (connection_id, io, incoming) =
124 peer.add_test_connection(server_conn, cx.background_executor().clone());
125 cx.background_spawn(io).detach();
126 {
127 let mut state = state.lock();
128 state.connection_id = Some(connection_id);
129 state.incoming = Some(incoming);
130 }
131 peer.send(
132 connection_id,
133 proto::Hello {
134 peer_id: Some(connection_id.into()),
135 },
136 )
137 .unwrap();
138
139 Ok(client_conn)
140 })
141 }
142 });
143
144 client
145 .connect(false, &cx.to_async())
146 .await
147 .into_response()
148 .unwrap();
149
150 server
151 }
152
153 pub fn disconnect(&self) {
154 if self.state.lock().connection_id.is_some() {
155 self.peer.disconnect(self.connection_id());
156 let mut state = self.state.lock();
157 state.connection_id.take();
158 state.incoming.take();
159 }
160 }
161
162 pub fn auth_count(&self) -> usize {
163 self.state.lock().auth_count
164 }
165
166 pub fn roll_access_token(&self) {
167 self.state.lock().access_token += 1;
168 }
169
170 pub fn forbid_connections(&self) {
171 self.state.lock().forbid_connections = true;
172 }
173
174 pub fn allow_connections(&self) {
175 self.state.lock().forbid_connections = false;
176 }
177
178 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
179 self.peer.send(self.connection_id(), message).unwrap();
180 }
181
182 #[allow(clippy::await_holding_lock)]
183 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
184 let message = self
185 .state
186 .lock()
187 .incoming
188 .as_mut()
189 .expect("not connected")
190 .next()
191 .await
192 .context("other half hung up")?;
193 let type_name = message.payload_type_name();
194 let message = message.into_any();
195
196 if message.is::<TypedEnvelope<M>>() {
197 return Ok(*message.downcast().unwrap());
198 }
199
200 panic!(
201 "fake server received unexpected message type: {:?}",
202 type_name
203 );
204 }
205
206 pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
207 self.peer.respond(receipt, response).unwrap()
208 }
209
210 fn connection_id(&self) -> ConnectionId {
211 self.state.lock().connection_id.expect("not connected")
212 }
213
214 pub async fn build_user_store(
215 &self,
216 client: Arc<Client>,
217 cx: &mut TestAppContext,
218 ) -> Entity<UserStore> {
219 let user_store = cx.new(|cx| UserStore::new(client, cx));
220 assert_eq!(
221 self.receive::<proto::GetUsers>()
222 .await
223 .unwrap()
224 .payload
225 .user_ids,
226 &[self.user_id]
227 );
228 user_store
229 }
230}
231
232impl Drop for FakeServer {
233 fn drop(&mut self) {
234 self.disconnect();
235 }
236}
237
238pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
239 let mut auth_header = req
240 .headers()
241 .get(http::header::AUTHORIZATION)?
242 .to_str()
243 .ok()?
244 .split_whitespace();
245 let user_id = auth_header.next()?.parse().ok()?;
246 let access_token = auth_header.next()?;
247 Some(Credentials {
248 user_id,
249 access_token: access_token.to_string(),
250 })
251}
252
253pub fn make_get_authenticated_user_response(
254 user_id: i32,
255 github_login: String,
256) -> GetAuthenticatedUserResponse {
257 GetAuthenticatedUserResponse {
258 user: AuthenticatedUser {
259 id: user_id,
260 metrics_id: format!("metrics-id-{user_id}"),
261 avatar_url: "".to_string(),
262 github_login,
263 name: None,
264 is_staff: false,
265 accepted_tos_at: None,
266 },
267 feature_flags: vec![],
268 plan: PlanInfo {
269 plan: KnownOrUnknown::Known(Plan::ZedPro),
270 subscription_period: None,
271 usage: CurrentUsage {
272 edit_predictions: UsageData {
273 used: 250,
274 limit: UsageLimit::Unlimited,
275 },
276 },
277 trial_started_at: None,
278 is_account_too_young: false,
279 has_overdue_invoices: false,
280 },
281 }
282}