test.rs

  1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use chrono::Duration;
  4use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
  5use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
  6use futures::{StreamExt, stream::BoxStream};
  7use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
  8use http_client::{AsyncBody, Method, Request, http};
  9use parking_lot::Mutex;
 10use rpc::{
 11    ConnectionId, Peer, Receipt, TypedEnvelope,
 12    proto::{self, GetPrivateUserInfo, GetPrivateUserInfoResponse},
 13};
 14use std::sync::Arc;
 15
 16pub struct FakeServer {
 17    peer: Arc<Peer>,
 18    state: Arc<Mutex<FakeServerState>>,
 19    user_id: u64,
 20    executor: BackgroundExecutor,
 21}
 22
 23#[derive(Default)]
 24struct FakeServerState {
 25    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 26    connection_id: Option<ConnectionId>,
 27    forbid_connections: bool,
 28    auth_count: usize,
 29    access_token: usize,
 30}
 31
 32impl FakeServer {
 33    pub async fn for_client(
 34        client_user_id: u64,
 35        client: &Arc<Client>,
 36        cx: &TestAppContext,
 37    ) -> Self {
 38        let server = Self {
 39            peer: Peer::new(0),
 40            state: Default::default(),
 41            user_id: client_user_id,
 42            executor: cx.executor(),
 43        };
 44
 45        client.http_client().as_fake().replace_handler({
 46            let state = server.state.clone();
 47            move |old_handler, req| {
 48                let state = state.clone();
 49                let old_handler = old_handler.clone();
 50                async move {
 51                    match (req.method(), req.uri().path()) {
 52                        (&Method::GET, "/client/users/me") => {
 53                            let credentials = parse_authorization_header(&req);
 54                            if credentials
 55                                != Some(Credentials {
 56                                    user_id: client_user_id,
 57                                    access_token: state.lock().access_token.to_string(),
 58                                })
 59                            {
 60                                return Ok(http_client::Response::builder()
 61                                    .status(401)
 62                                    .body("Unauthorized".into())
 63                                    .unwrap());
 64                            }
 65
 66                            Ok(http_client::Response::builder()
 67                                .status(200)
 68                                .body(
 69                                    serde_json::to_string(&make_get_authenticated_user_response(
 70                                        client_user_id as i32,
 71                                        format!("user-{client_user_id}"),
 72                                    ))
 73                                    .unwrap()
 74                                    .into(),
 75                                )
 76                                .unwrap())
 77                        }
 78                        _ => old_handler(req).await,
 79                    }
 80                }
 81            }
 82        });
 83        client
 84            .override_authenticate({
 85                let state = Arc::downgrade(&server.state);
 86                move |cx| {
 87                    let state = state.clone();
 88                    cx.spawn(async move |_| {
 89                        let state = state.upgrade().context("server dropped")?;
 90                        let mut state = state.lock();
 91                        state.auth_count += 1;
 92                        let access_token = state.access_token.to_string();
 93                        Ok(Credentials {
 94                            user_id: client_user_id,
 95                            access_token,
 96                        })
 97                    })
 98                }
 99            })
100            .override_establish_connection({
101                let peer = Arc::downgrade(&server.peer);
102                let state = Arc::downgrade(&server.state);
103                move |credentials, cx| {
104                    let peer = peer.clone();
105                    let state = state.clone();
106                    let credentials = credentials.clone();
107                    cx.spawn(async move |cx| {
108                        let state = state.upgrade().context("server dropped")?;
109                        let peer = peer.upgrade().context("server dropped")?;
110                        if state.lock().forbid_connections {
111                            Err(EstablishConnectionError::Other(anyhow!(
112                                "server is forbidding connections"
113                            )))?
114                        }
115
116                        if credentials
117                            != (Credentials {
118                                user_id: client_user_id,
119                                access_token: state.lock().access_token.to_string(),
120                            })
121                        {
122                            Err(EstablishConnectionError::Unauthorized)?
123                        }
124
125                        let (client_conn, server_conn, _) =
126                            Connection::in_memory(cx.background_executor().clone());
127                        let (connection_id, io, incoming) =
128                            peer.add_test_connection(server_conn, cx.background_executor().clone());
129                        cx.background_spawn(io).detach();
130                        {
131                            let mut state = state.lock();
132                            state.connection_id = Some(connection_id);
133                            state.incoming = Some(incoming);
134                        }
135                        peer.send(
136                            connection_id,
137                            proto::Hello {
138                                peer_id: Some(connection_id.into()),
139                            },
140                        )
141                        .unwrap();
142
143                        Ok(client_conn)
144                    })
145                }
146            });
147
148        client
149            .connect(false, &cx.to_async())
150            .await
151            .into_response()
152            .unwrap();
153
154        server
155    }
156
157    pub fn disconnect(&self) {
158        if self.state.lock().connection_id.is_some() {
159            self.peer.disconnect(self.connection_id());
160            let mut state = self.state.lock();
161            state.connection_id.take();
162            state.incoming.take();
163        }
164    }
165
166    pub fn auth_count(&self) -> usize {
167        self.state.lock().auth_count
168    }
169
170    pub fn roll_access_token(&self) {
171        self.state.lock().access_token += 1;
172    }
173
174    pub fn forbid_connections(&self) {
175        self.state.lock().forbid_connections = true;
176    }
177
178    pub fn allow_connections(&self) {
179        self.state.lock().forbid_connections = false;
180    }
181
182    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
183        self.peer.send(self.connection_id(), message).unwrap();
184    }
185
186    #[allow(clippy::await_holding_lock)]
187    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
188        self.executor.start_waiting();
189
190        loop {
191            let message = self
192                .state
193                .lock()
194                .incoming
195                .as_mut()
196                .expect("not connected")
197                .next()
198                .await
199                .context("other half hung up")?;
200            self.executor.finish_waiting();
201            let type_name = message.payload_type_name();
202            let message = message.into_any();
203
204            if message.is::<TypedEnvelope<M>>() {
205                return Ok(*message.downcast().unwrap());
206            }
207
208            let accepted_tos_at = chrono::Utc::now()
209                .checked_sub_signed(Duration::hours(5))
210                .expect("failed to build accepted_tos_at")
211                .timestamp() as u64;
212
213            if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
214                self.respond(
215                    message
216                        .downcast::<TypedEnvelope<GetPrivateUserInfo>>()
217                        .unwrap()
218                        .receipt(),
219                    GetPrivateUserInfoResponse {
220                        metrics_id: "the-metrics-id".into(),
221                        staff: false,
222                        flags: Default::default(),
223                        accepted_tos_at: Some(accepted_tos_at),
224                    },
225                );
226                continue;
227            }
228
229            panic!(
230                "fake server received unexpected message type: {:?}",
231                type_name
232            );
233        }
234    }
235
236    pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
237        self.peer.respond(receipt, response).unwrap()
238    }
239
240    fn connection_id(&self) -> ConnectionId {
241        self.state.lock().connection_id.expect("not connected")
242    }
243
244    pub async fn build_user_store(
245        &self,
246        client: Arc<Client>,
247        cx: &mut TestAppContext,
248    ) -> Entity<UserStore> {
249        let user_store = cx.new(|cx| UserStore::new(client, cx));
250        assert_eq!(
251            self.receive::<proto::GetUsers>()
252                .await
253                .unwrap()
254                .payload
255                .user_ids,
256            &[self.user_id]
257        );
258        user_store
259    }
260}
261
262impl Drop for FakeServer {
263    fn drop(&mut self) {
264        self.disconnect();
265    }
266}
267
268pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
269    let mut auth_header = req
270        .headers()
271        .get(http::header::AUTHORIZATION)?
272        .to_str()
273        .ok()?
274        .split_whitespace();
275    let user_id = auth_header.next()?.parse().ok()?;
276    let access_token = auth_header.next()?;
277    Some(Credentials {
278        user_id,
279        access_token: access_token.to_string(),
280    })
281}
282
283pub fn make_get_authenticated_user_response(
284    user_id: i32,
285    github_login: String,
286) -> GetAuthenticatedUserResponse {
287    GetAuthenticatedUserResponse {
288        user: AuthenticatedUser {
289            id: user_id,
290            metrics_id: format!("metrics-id-{user_id}"),
291            avatar_url: "".to_string(),
292            github_login,
293            name: None,
294            is_staff: false,
295            accepted_tos_at: None,
296        },
297        feature_flags: vec![],
298        plan: PlanInfo {
299            plan: Plan::ZedPro,
300            subscription_period: None,
301            usage: CurrentUsage {
302                model_requests: UsageData {
303                    used: 0,
304                    limit: UsageLimit::Limited(500),
305                },
306                edit_predictions: UsageData {
307                    used: 250,
308                    limit: UsageLimit::Unlimited,
309                },
310            },
311            trial_started_at: None,
312            is_usage_based_billing_enabled: false,
313            is_account_too_young: false,
314            has_overdue_invoices: false,
315        },
316    }
317}