1use std::collections::BTreeMap;
2use std::sync::Arc;
3
4use anyhow::{Context as _, Result, anyhow};
5use cloud_api_client::{
6 AuthenticatedUser, GetAuthenticatedUserResponse, KnownOrUnknown, Plan, PlanInfo,
7};
8use cloud_llm_client::{CurrentUsage, UsageData, UsageLimit};
9use futures::{StreamExt, stream::BoxStream};
10use gpui::{AppContext as _, Entity, TestAppContext};
11use http_client::{AsyncBody, Method, Request, http};
12use parking_lot::Mutex;
13use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
14
15use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
16
/// Test double that plays the server side of a [`Client`]'s connection.
///
/// Created with [`FakeServer::for_client`], which replaces the client's HTTP
/// handler, authentication hook, and connection-establishment hook so they talk
/// to this in-process peer instead of a real backend.
pub struct FakeServer {
    /// Server-side RPC peer that owns the in-memory connection to the client.
    peer: Arc<Peer>,
    /// Mutable state shared with the closures installed on the client.
    state: Arc<Mutex<FakeServerState>>,
    /// User id the client authenticates as; `build_user_store` asserts the
    /// initial `GetUsers` request asks for exactly this id.
    user_id: u64,
}
22
/// Mutable state of a [`FakeServer`], shared with the hooks installed on the client.
#[derive(Default)]
struct FakeServerState {
    /// Messages sent by the client on the current connection; `None` when not connected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current server-side connection; `None` when not connected.
    connection_id: Option<ConnectionId>,
    /// When `true`, connection attempts fail with `EstablishConnectionError::Other`.
    forbid_connections: bool,
    /// Number of times the client's (overridden) authenticate hook has run.
    auth_count: usize,
    /// Current access token; the client's credentials must carry its decimal
    /// string form to be accepted by the HTTP handler and by connection setup.
    access_token: usize,
}
31
32impl FakeServer {
33 pub async fn for_client(
34 client_user_id: u64,
35 client: &Arc<Client>,
36 cx: &TestAppContext,
37 ) -> Self {
38 let server = Self {
39 peer: Peer::new(0),
40 state: Default::default(),
41 user_id: client_user_id,
42 };
43
44 client.http_client().as_fake().replace_handler({
45 let state = server.state.clone();
46 move |old_handler, req| {
47 let state = state.clone();
48 let old_handler = old_handler.clone();
49 async move {
50 match (req.method(), req.uri().path()) {
51 (&Method::GET, "/client/users/me") => {
52 let credentials = parse_authorization_header(&req);
53 if credentials
54 != Some(Credentials {
55 user_id: client_user_id,
56 access_token: state.lock().access_token.to_string(),
57 })
58 {
59 return Ok(http_client::Response::builder()
60 .status(401)
61 .body("Unauthorized".into())
62 .unwrap());
63 }
64
65 Ok(http_client::Response::builder()
66 .status(200)
67 .body(
68 serde_json::to_string(&make_get_authenticated_user_response(
69 client_user_id as i32,
70 format!("user-{client_user_id}"),
71 ))
72 .unwrap()
73 .into(),
74 )
75 .unwrap())
76 }
77 _ => old_handler(req).await,
78 }
79 }
80 }
81 });
82 client
83 .override_authenticate({
84 let state = Arc::downgrade(&server.state);
85 move |cx| {
86 let state = state.clone();
87 cx.spawn(async move |_| {
88 let state = state.upgrade().context("server dropped")?;
89 let mut state = state.lock();
90 state.auth_count += 1;
91 let access_token = state.access_token.to_string();
92 Ok(Credentials {
93 user_id: client_user_id,
94 access_token,
95 })
96 })
97 }
98 })
99 .override_establish_connection({
100 let peer = Arc::downgrade(&server.peer);
101 let state = Arc::downgrade(&server.state);
102 move |credentials, cx| {
103 let peer = peer.clone();
104 let state = state.clone();
105 let credentials = credentials.clone();
106 cx.spawn(async move |cx| {
107 let state = state.upgrade().context("server dropped")?;
108 let peer = peer.upgrade().context("server dropped")?;
109 if state.lock().forbid_connections {
110 Err(EstablishConnectionError::Other(anyhow!(
111 "server is forbidding connections"
112 )))?
113 }
114
115 if credentials
116 != (Credentials {
117 user_id: client_user_id,
118 access_token: state.lock().access_token.to_string(),
119 })
120 {
121 Err(EstablishConnectionError::Unauthorized)?
122 }
123
124 let (client_conn, server_conn, _) =
125 Connection::in_memory(cx.background_executor().clone());
126 let (connection_id, io, incoming) =
127 peer.add_test_connection(server_conn, cx.background_executor().clone());
128 cx.background_spawn(io).detach();
129 {
130 let mut state = state.lock();
131 state.connection_id = Some(connection_id);
132 state.incoming = Some(incoming);
133 }
134 peer.send(
135 connection_id,
136 proto::Hello {
137 peer_id: Some(connection_id.into()),
138 },
139 )
140 .unwrap();
141
142 Ok(client_conn)
143 })
144 }
145 });
146
147 client
148 .connect(false, &cx.to_async())
149 .await
150 .into_response()
151 .unwrap();
152
153 server
154 }
155
156 pub fn disconnect(&self) {
157 if self.state.lock().connection_id.is_some() {
158 self.peer.disconnect(self.connection_id());
159 let mut state = self.state.lock();
160 state.connection_id.take();
161 state.incoming.take();
162 }
163 }
164
165 pub fn auth_count(&self) -> usize {
166 self.state.lock().auth_count
167 }
168
169 pub fn roll_access_token(&self) {
170 self.state.lock().access_token += 1;
171 }
172
173 pub fn forbid_connections(&self) {
174 self.state.lock().forbid_connections = true;
175 }
176
177 pub fn allow_connections(&self) {
178 self.state.lock().forbid_connections = false;
179 }
180
181 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
182 self.peer.send(self.connection_id(), message).unwrap();
183 }
184
185 #[allow(clippy::await_holding_lock)]
186 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
187 let message = self
188 .state
189 .lock()
190 .incoming
191 .as_mut()
192 .expect("not connected")
193 .next()
194 .await
195 .context("other half hung up")?;
196 let type_name = message.payload_type_name();
197 let message = message.into_any();
198
199 if message.is::<TypedEnvelope<M>>() {
200 return Ok(*message.downcast().unwrap());
201 }
202
203 panic!(
204 "fake server received unexpected message type: {:?}",
205 type_name
206 );
207 }
208
209 pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
210 self.peer.respond(receipt, response).unwrap()
211 }
212
213 fn connection_id(&self) -> ConnectionId {
214 self.state.lock().connection_id.expect("not connected")
215 }
216
217 pub async fn build_user_store(
218 &self,
219 client: Arc<Client>,
220 cx: &mut TestAppContext,
221 ) -> Entity<UserStore> {
222 let user_store = cx.new(|cx| UserStore::new(client, cx));
223 assert_eq!(
224 self.receive::<proto::GetUsers>()
225 .await
226 .unwrap()
227 .payload
228 .user_ids,
229 &[self.user_id]
230 );
231 user_store
232 }
233}
234
235impl Drop for FakeServer {
236 fn drop(&mut self) {
237 self.disconnect();
238 }
239}
240
241pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
242 let mut auth_header = req
243 .headers()
244 .get(http::header::AUTHORIZATION)?
245 .to_str()
246 .ok()?
247 .split_whitespace();
248 let user_id = auth_header.next()?.parse().ok()?;
249 let access_token = auth_header.next()?;
250 Some(Credentials {
251 user_id,
252 access_token: access_token.to_string(),
253 })
254}
255
256pub fn make_get_authenticated_user_response(
257 user_id: i32,
258 github_login: String,
259) -> GetAuthenticatedUserResponse {
260 GetAuthenticatedUserResponse {
261 user: AuthenticatedUser {
262 id: user_id,
263 metrics_id: format!("metrics-id-{user_id}"),
264 avatar_url: "".to_string(),
265 github_login,
266 name: None,
267 is_staff: false,
268 accepted_tos_at: None,
269 },
270 feature_flags: vec![],
271 organizations: vec![],
272 plans_by_organization: BTreeMap::new(),
273 plan: PlanInfo {
274 plan: KnownOrUnknown::Known(Plan::ZedPro),
275 subscription_period: None,
276 usage: CurrentUsage {
277 edit_predictions: UsageData {
278 used: 250,
279 limit: UsageLimit::Unlimited,
280 },
281 },
282 trial_started_at: None,
283 is_account_too_young: false,
284 has_overdue_invoices: false,
285 },
286 }
287}