1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
2use anyhow::{Context as _, Result, anyhow};
3use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
4use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
5use futures::{StreamExt, stream::BoxStream};
6use gpui::{AppContext as _, Entity, TestAppContext};
7use http_client::{AsyncBody, Method, Request, http};
8use parking_lot::Mutex;
9use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
10use std::sync::Arc;
11
12pub struct FakeServer {
13 peer: Arc<Peer>,
14 state: Arc<Mutex<FakeServerState>>,
15 user_id: u64,
16}
17
18#[derive(Default)]
19struct FakeServerState {
20 incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
21 connection_id: Option<ConnectionId>,
22 forbid_connections: bool,
23 auth_count: usize,
24 access_token: usize,
25}
26
27impl FakeServer {
28 pub async fn for_client(
29 client_user_id: u64,
30 client: &Arc<Client>,
31 cx: &TestAppContext,
32 ) -> Self {
33 let server = Self {
34 peer: Peer::new(0),
35 state: Default::default(),
36 user_id: client_user_id,
37 };
38
39 client.http_client().as_fake().replace_handler({
40 let state = server.state.clone();
41 move |old_handler, req| {
42 let state = state.clone();
43 let old_handler = old_handler.clone();
44 async move {
45 match (req.method(), req.uri().path()) {
46 (&Method::GET, "/client/users/me") => {
47 let credentials = parse_authorization_header(&req);
48 if credentials
49 != Some(Credentials {
50 user_id: client_user_id,
51 access_token: state.lock().access_token.to_string(),
52 })
53 {
54 return Ok(http_client::Response::builder()
55 .status(401)
56 .body("Unauthorized".into())
57 .unwrap());
58 }
59
60 Ok(http_client::Response::builder()
61 .status(200)
62 .body(
63 serde_json::to_string(&make_get_authenticated_user_response(
64 client_user_id as i32,
65 format!("user-{client_user_id}"),
66 ))
67 .unwrap()
68 .into(),
69 )
70 .unwrap())
71 }
72 _ => old_handler(req).await,
73 }
74 }
75 }
76 });
77 client
78 .override_authenticate({
79 let state = Arc::downgrade(&server.state);
80 move |cx| {
81 let state = state.clone();
82 cx.spawn(async move |_| {
83 let state = state.upgrade().context("server dropped")?;
84 let mut state = state.lock();
85 state.auth_count += 1;
86 let access_token = state.access_token.to_string();
87 Ok(Credentials {
88 user_id: client_user_id,
89 access_token,
90 })
91 })
92 }
93 })
94 .override_establish_connection({
95 let peer = Arc::downgrade(&server.peer);
96 let state = Arc::downgrade(&server.state);
97 move |credentials, cx| {
98 let peer = peer.clone();
99 let state = state.clone();
100 let credentials = credentials.clone();
101 cx.spawn(async move |cx| {
102 let state = state.upgrade().context("server dropped")?;
103 let peer = peer.upgrade().context("server dropped")?;
104 if state.lock().forbid_connections {
105 Err(EstablishConnectionError::Other(anyhow!(
106 "server is forbidding connections"
107 )))?
108 }
109
110 if credentials
111 != (Credentials {
112 user_id: client_user_id,
113 access_token: state.lock().access_token.to_string(),
114 })
115 {
116 Err(EstablishConnectionError::Unauthorized)?
117 }
118
119 let (client_conn, server_conn, _) =
120 Connection::in_memory(cx.background_executor().clone());
121 let (connection_id, io, incoming) =
122 peer.add_test_connection(server_conn, cx.background_executor().clone());
123 cx.background_spawn(io).detach();
124 {
125 let mut state = state.lock();
126 state.connection_id = Some(connection_id);
127 state.incoming = Some(incoming);
128 }
129 peer.send(
130 connection_id,
131 proto::Hello {
132 peer_id: Some(connection_id.into()),
133 },
134 )
135 .unwrap();
136
137 Ok(client_conn)
138 })
139 }
140 });
141
142 client
143 .connect(false, &cx.to_async())
144 .await
145 .into_response()
146 .unwrap();
147
148 server
149 }
150
151 pub fn disconnect(&self) {
152 if self.state.lock().connection_id.is_some() {
153 self.peer.disconnect(self.connection_id());
154 let mut state = self.state.lock();
155 state.connection_id.take();
156 state.incoming.take();
157 }
158 }
159
160 pub fn auth_count(&self) -> usize {
161 self.state.lock().auth_count
162 }
163
164 pub fn roll_access_token(&self) {
165 self.state.lock().access_token += 1;
166 }
167
168 pub fn forbid_connections(&self) {
169 self.state.lock().forbid_connections = true;
170 }
171
172 pub fn allow_connections(&self) {
173 self.state.lock().forbid_connections = false;
174 }
175
176 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
177 self.peer.send(self.connection_id(), message).unwrap();
178 }
179
180 #[allow(clippy::await_holding_lock)]
181 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
182 let message = self
183 .state
184 .lock()
185 .incoming
186 .as_mut()
187 .expect("not connected")
188 .next()
189 .await
190 .context("other half hung up")?;
191 let type_name = message.payload_type_name();
192 let message = message.into_any();
193
194 if message.is::<TypedEnvelope<M>>() {
195 return Ok(*message.downcast().unwrap());
196 }
197
198 panic!(
199 "fake server received unexpected message type: {:?}",
200 type_name
201 );
202 }
203
204 pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
205 self.peer.respond(receipt, response).unwrap()
206 }
207
208 fn connection_id(&self) -> ConnectionId {
209 self.state.lock().connection_id.expect("not connected")
210 }
211
212 pub async fn build_user_store(
213 &self,
214 client: Arc<Client>,
215 cx: &mut TestAppContext,
216 ) -> Entity<UserStore> {
217 let user_store = cx.new(|cx| UserStore::new(client, cx));
218 assert_eq!(
219 self.receive::<proto::GetUsers>()
220 .await
221 .unwrap()
222 .payload
223 .user_ids,
224 &[self.user_id]
225 );
226 user_store
227 }
228}
229
230impl Drop for FakeServer {
231 fn drop(&mut self) {
232 self.disconnect();
233 }
234}
235
236pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
237 let mut auth_header = req
238 .headers()
239 .get(http::header::AUTHORIZATION)?
240 .to_str()
241 .ok()?
242 .split_whitespace();
243 let user_id = auth_header.next()?.parse().ok()?;
244 let access_token = auth_header.next()?;
245 Some(Credentials {
246 user_id,
247 access_token: access_token.to_string(),
248 })
249}
250
251pub fn make_get_authenticated_user_response(
252 user_id: i32,
253 github_login: String,
254) -> GetAuthenticatedUserResponse {
255 GetAuthenticatedUserResponse {
256 user: AuthenticatedUser {
257 id: user_id,
258 metrics_id: format!("metrics-id-{user_id}"),
259 avatar_url: "".to_string(),
260 github_login,
261 name: None,
262 is_staff: false,
263 accepted_tos_at: None,
264 },
265 feature_flags: vec![],
266 plan: PlanInfo {
267 plan: Plan::ZedPro,
268 subscription_period: None,
269 usage: CurrentUsage {
270 model_requests: UsageData {
271 used: 0,
272 limit: UsageLimit::Limited(500),
273 },
274 edit_predictions: UsageData {
275 used: 250,
276 limit: UsageLimit::Unlimited,
277 },
278 },
279 trial_started_at: None,
280 is_usage_based_billing_enabled: false,
281 is_account_too_young: false,
282 has_overdue_invoices: false,
283 },
284 }
285}