test.rs

  1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
  4use cloud_llm_client::{CurrentUsage, PlanV1, UsageData, UsageLimit};
  5use futures::{StreamExt, stream::BoxStream};
  6use gpui::{AppContext as _, BackgroundExecutor, Entity, TestAppContext};
  7use http_client::{AsyncBody, Method, Request, http};
  8use parking_lot::Mutex;
  9use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
 10use std::sync::Arc;
 11
 12pub struct FakeServer {
 13    peer: Arc<Peer>,
 14    state: Arc<Mutex<FakeServerState>>,
 15    user_id: u64,
 16    executor: BackgroundExecutor,
 17}
 18
 19#[derive(Default)]
 20struct FakeServerState {
 21    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 22    connection_id: Option<ConnectionId>,
 23    forbid_connections: bool,
 24    auth_count: usize,
 25    access_token: usize,
 26}
 27
 28impl FakeServer {
 29    pub async fn for_client(
 30        client_user_id: u64,
 31        client: &Arc<Client>,
 32        cx: &TestAppContext,
 33    ) -> Self {
 34        let server = Self {
 35            peer: Peer::new(0),
 36            state: Default::default(),
 37            user_id: client_user_id,
 38            executor: cx.executor(),
 39        };
 40
 41        client.http_client().as_fake().replace_handler({
 42            let state = server.state.clone();
 43            move |old_handler, req| {
 44                let state = state.clone();
 45                let old_handler = old_handler.clone();
 46                async move {
 47                    match (req.method(), req.uri().path()) {
 48                        (&Method::GET, "/client/users/me") => {
 49                            let credentials = parse_authorization_header(&req);
 50                            if credentials
 51                                != Some(Credentials {
 52                                    user_id: client_user_id,
 53                                    access_token: state.lock().access_token.to_string(),
 54                                })
 55                            {
 56                                return Ok(http_client::Response::builder()
 57                                    .status(401)
 58                                    .body("Unauthorized".into())
 59                                    .unwrap());
 60                            }
 61
 62                            Ok(http_client::Response::builder()
 63                                .status(200)
 64                                .body(
 65                                    serde_json::to_string(&make_get_authenticated_user_response(
 66                                        client_user_id as i32,
 67                                        format!("user-{client_user_id}"),
 68                                    ))
 69                                    .unwrap()
 70                                    .into(),
 71                                )
 72                                .unwrap())
 73                        }
 74                        _ => old_handler(req).await,
 75                    }
 76                }
 77            }
 78        });
 79        client
 80            .override_authenticate({
 81                let state = Arc::downgrade(&server.state);
 82                move |cx| {
 83                    let state = state.clone();
 84                    cx.spawn(async move |_| {
 85                        let state = state.upgrade().context("server dropped")?;
 86                        let mut state = state.lock();
 87                        state.auth_count += 1;
 88                        let access_token = state.access_token.to_string();
 89                        Ok(Credentials {
 90                            user_id: client_user_id,
 91                            access_token,
 92                        })
 93                    })
 94                }
 95            })
 96            .override_establish_connection({
 97                let peer = Arc::downgrade(&server.peer);
 98                let state = Arc::downgrade(&server.state);
 99                move |credentials, cx| {
100                    let peer = peer.clone();
101                    let state = state.clone();
102                    let credentials = credentials.clone();
103                    cx.spawn(async move |cx| {
104                        let state = state.upgrade().context("server dropped")?;
105                        let peer = peer.upgrade().context("server dropped")?;
106                        if state.lock().forbid_connections {
107                            Err(EstablishConnectionError::Other(anyhow!(
108                                "server is forbidding connections"
109                            )))?
110                        }
111
112                        if credentials
113                            != (Credentials {
114                                user_id: client_user_id,
115                                access_token: state.lock().access_token.to_string(),
116                            })
117                        {
118                            Err(EstablishConnectionError::Unauthorized)?
119                        }
120
121                        let (client_conn, server_conn, _) =
122                            Connection::in_memory(cx.background_executor().clone());
123                        let (connection_id, io, incoming) =
124                            peer.add_test_connection(server_conn, cx.background_executor().clone());
125                        cx.background_spawn(io).detach();
126                        {
127                            let mut state = state.lock();
128                            state.connection_id = Some(connection_id);
129                            state.incoming = Some(incoming);
130                        }
131                        peer.send(
132                            connection_id,
133                            proto::Hello {
134                                peer_id: Some(connection_id.into()),
135                            },
136                        )
137                        .unwrap();
138
139                        Ok(client_conn)
140                    })
141                }
142            });
143
144        client
145            .connect(false, &cx.to_async())
146            .await
147            .into_response()
148            .unwrap();
149
150        server
151    }
152
153    pub fn disconnect(&self) {
154        if self.state.lock().connection_id.is_some() {
155            self.peer.disconnect(self.connection_id());
156            let mut state = self.state.lock();
157            state.connection_id.take();
158            state.incoming.take();
159        }
160    }
161
162    pub fn auth_count(&self) -> usize {
163        self.state.lock().auth_count
164    }
165
166    pub fn roll_access_token(&self) {
167        self.state.lock().access_token += 1;
168    }
169
170    pub fn forbid_connections(&self) {
171        self.state.lock().forbid_connections = true;
172    }
173
174    pub fn allow_connections(&self) {
175        self.state.lock().forbid_connections = false;
176    }
177
178    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
179        self.peer.send(self.connection_id(), message).unwrap();
180    }
181
182    #[allow(clippy::await_holding_lock)]
183    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
184        self.executor.start_waiting();
185
186        let message = self
187            .state
188            .lock()
189            .incoming
190            .as_mut()
191            .expect("not connected")
192            .next()
193            .await
194            .context("other half hung up")?;
195        self.executor.finish_waiting();
196        let type_name = message.payload_type_name();
197        let message = message.into_any();
198
199        if message.is::<TypedEnvelope<M>>() {
200            return Ok(*message.downcast().unwrap());
201        }
202
203        panic!(
204            "fake server received unexpected message type: {:?}",
205            type_name
206        );
207    }
208
209    pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
210        self.peer.respond(receipt, response).unwrap()
211    }
212
213    fn connection_id(&self) -> ConnectionId {
214        self.state.lock().connection_id.expect("not connected")
215    }
216
217    pub async fn build_user_store(
218        &self,
219        client: Arc<Client>,
220        cx: &mut TestAppContext,
221    ) -> Entity<UserStore> {
222        let user_store = cx.new(|cx| UserStore::new(client, cx));
223        assert_eq!(
224            self.receive::<proto::GetUsers>()
225                .await
226                .unwrap()
227                .payload
228                .user_ids,
229            &[self.user_id]
230        );
231        user_store
232    }
233}
234
235impl Drop for FakeServer {
236    fn drop(&mut self) {
237        self.disconnect();
238    }
239}
240
241pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
242    let mut auth_header = req
243        .headers()
244        .get(http::header::AUTHORIZATION)?
245        .to_str()
246        .ok()?
247        .split_whitespace();
248    let user_id = auth_header.next()?.parse().ok()?;
249    let access_token = auth_header.next()?;
250    Some(Credentials {
251        user_id,
252        access_token: access_token.to_string(),
253    })
254}
255
256pub fn make_get_authenticated_user_response(
257    user_id: i32,
258    github_login: String,
259) -> GetAuthenticatedUserResponse {
260    GetAuthenticatedUserResponse {
261        user: AuthenticatedUser {
262            id: user_id,
263            metrics_id: format!("metrics-id-{user_id}"),
264            avatar_url: "".to_string(),
265            github_login,
266            name: None,
267            is_staff: false,
268            accepted_tos_at: None,
269        },
270        feature_flags: vec![],
271        plan: PlanInfo {
272            plan: PlanV1::ZedPro,
273            plan_v2: None,
274            subscription_period: None,
275            usage: CurrentUsage {
276                model_requests: UsageData {
277                    used: 0,
278                    limit: UsageLimit::Limited(500),
279                },
280                edit_predictions: UsageData {
281                    used: 250,
282                    limit: UsageLimit::Unlimited,
283                },
284            },
285            trial_started_at: None,
286            is_usage_based_billing_enabled: false,
287            is_account_too_young: false,
288            has_overdue_invoices: false,
289        },
290    }
291}