test.rs

  1use std::collections::BTreeMap;
  2use std::sync::Arc;
  3
  4use anyhow::{Context as _, Result, anyhow};
  5use cloud_api_client::{
  6    AuthenticatedUser, GetAuthenticatedUserResponse, KnownOrUnknown, Plan, PlanInfo,
  7};
  8use cloud_llm_client::{CurrentUsage, UsageData, UsageLimit};
  9use futures::{StreamExt, stream::BoxStream};
 10use gpui::{AppContext as _, Entity, TestAppContext};
 11use http_client::{AsyncBody, Method, Request, http};
 12use parking_lot::Mutex;
 13use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
 14
 15use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
 16
/// A fake server that a [`Client`] is wired up to in tests (see `FakeServer::for_client`).
///
/// The server side of the connection is driven manually from the test via
/// `send`, `receive`, `respond`, `disconnect`, and the connection/auth toggles.
pub struct FakeServer {
    /// RPC peer that owns the server half of the in-memory connection.
    peer: Arc<Peer>,
    /// Mutable server state, shared with the hooks installed on the client.
    state: Arc<Mutex<FakeServerState>>,
    /// User id the client was set up with; `build_user_store` checks that the
    /// first `GetUsers` request asks for exactly this user.
    user_id: u64,
}
 22
/// Mutable state of a [`FakeServer`], shared with the hooks it installs on the client.
#[derive(Default)]
struct FakeServerState {
    /// Messages sent by the client on the current connection; `None` when not connected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current connection; `None` when not connected.
    connection_id: Option<ConnectionId>,
    /// When `true`, the establish-connection hook rejects new connections.
    forbid_connections: bool,
    /// Number of times the client's authenticate hook has been invoked.
    auth_count: usize,
    /// Current access token, compared in its stringified form against the
    /// client's credentials; incremented by `roll_access_token`.
    access_token: usize,
}
 31
 32impl FakeServer {
 33    pub async fn for_client(
 34        client_user_id: u64,
 35        client: &Arc<Client>,
 36        cx: &TestAppContext,
 37    ) -> Self {
 38        let server = Self {
 39            peer: Peer::new(0),
 40            state: Default::default(),
 41            user_id: client_user_id,
 42        };
 43
 44        client.http_client().as_fake().replace_handler({
 45            let state = server.state.clone();
 46            move |old_handler, req| {
 47                let state = state.clone();
 48                let old_handler = old_handler.clone();
 49                async move {
 50                    match (req.method(), req.uri().path()) {
 51                        (&Method::GET, "/client/users/me") => {
 52                            let credentials = parse_authorization_header(&req);
 53                            if credentials
 54                                != Some(Credentials {
 55                                    user_id: client_user_id,
 56                                    access_token: state.lock().access_token.to_string(),
 57                                })
 58                            {
 59                                return Ok(http_client::Response::builder()
 60                                    .status(401)
 61                                    .body("Unauthorized".into())
 62                                    .unwrap());
 63                            }
 64
 65                            Ok(http_client::Response::builder()
 66                                .status(200)
 67                                .body(
 68                                    serde_json::to_string(&make_get_authenticated_user_response(
 69                                        client_user_id as i32,
 70                                        format!("user-{client_user_id}"),
 71                                    ))
 72                                    .unwrap()
 73                                    .into(),
 74                                )
 75                                .unwrap())
 76                        }
 77                        _ => old_handler(req).await,
 78                    }
 79                }
 80            }
 81        });
 82        client
 83            .override_authenticate({
 84                let state = Arc::downgrade(&server.state);
 85                move |cx| {
 86                    let state = state.clone();
 87                    cx.spawn(async move |_| {
 88                        let state = state.upgrade().context("server dropped")?;
 89                        let mut state = state.lock();
 90                        state.auth_count += 1;
 91                        let access_token = state.access_token.to_string();
 92                        Ok(Credentials {
 93                            user_id: client_user_id,
 94                            access_token,
 95                        })
 96                    })
 97                }
 98            })
 99            .override_establish_connection({
100                let peer = Arc::downgrade(&server.peer);
101                let state = Arc::downgrade(&server.state);
102                move |credentials, cx| {
103                    let peer = peer.clone();
104                    let state = state.clone();
105                    let credentials = credentials.clone();
106                    cx.spawn(async move |cx| {
107                        let state = state.upgrade().context("server dropped")?;
108                        let peer = peer.upgrade().context("server dropped")?;
109                        if state.lock().forbid_connections {
110                            Err(EstablishConnectionError::Other(anyhow!(
111                                "server is forbidding connections"
112                            )))?
113                        }
114
115                        if credentials
116                            != (Credentials {
117                                user_id: client_user_id,
118                                access_token: state.lock().access_token.to_string(),
119                            })
120                        {
121                            Err(EstablishConnectionError::Unauthorized)?
122                        }
123
124                        let (client_conn, server_conn, _) =
125                            Connection::in_memory(cx.background_executor().clone());
126                        let (connection_id, io, incoming) =
127                            peer.add_test_connection(server_conn, cx.background_executor().clone());
128                        cx.background_spawn(io).detach();
129                        {
130                            let mut state = state.lock();
131                            state.connection_id = Some(connection_id);
132                            state.incoming = Some(incoming);
133                        }
134                        peer.send(
135                            connection_id,
136                            proto::Hello {
137                                peer_id: Some(connection_id.into()),
138                            },
139                        )
140                        .unwrap();
141
142                        Ok(client_conn)
143                    })
144                }
145            });
146
147        client
148            .connect(false, &cx.to_async())
149            .await
150            .into_response()
151            .unwrap();
152
153        server
154    }
155
156    pub fn disconnect(&self) {
157        if self.state.lock().connection_id.is_some() {
158            self.peer.disconnect(self.connection_id());
159            let mut state = self.state.lock();
160            state.connection_id.take();
161            state.incoming.take();
162        }
163    }
164
165    pub fn auth_count(&self) -> usize {
166        self.state.lock().auth_count
167    }
168
169    pub fn roll_access_token(&self) {
170        self.state.lock().access_token += 1;
171    }
172
173    pub fn forbid_connections(&self) {
174        self.state.lock().forbid_connections = true;
175    }
176
177    pub fn allow_connections(&self) {
178        self.state.lock().forbid_connections = false;
179    }
180
181    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
182        self.peer.send(self.connection_id(), message).unwrap();
183    }
184
185    #[allow(clippy::await_holding_lock)]
186    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
187        let message = self
188            .state
189            .lock()
190            .incoming
191            .as_mut()
192            .expect("not connected")
193            .next()
194            .await
195            .context("other half hung up")?;
196        let type_name = message.payload_type_name();
197        let message = message.into_any();
198
199        if message.is::<TypedEnvelope<M>>() {
200            return Ok(*message.downcast().unwrap());
201        }
202
203        panic!(
204            "fake server received unexpected message type: {:?}",
205            type_name
206        );
207    }
208
209    pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
210        self.peer.respond(receipt, response).unwrap()
211    }
212
213    fn connection_id(&self) -> ConnectionId {
214        self.state.lock().connection_id.expect("not connected")
215    }
216
217    pub async fn build_user_store(
218        &self,
219        client: Arc<Client>,
220        cx: &mut TestAppContext,
221    ) -> Entity<UserStore> {
222        let user_store = cx.new(|cx| UserStore::new(client, cx));
223        assert_eq!(
224            self.receive::<proto::GetUsers>()
225                .await
226                .unwrap()
227                .payload
228                .user_ids,
229            &[self.user_id]
230        );
231        user_store
232    }
233}
234
/// Disconnects the fake server's current connection (if any) when it goes out of scope.
impl Drop for FakeServer {
    fn drop(&mut self) {
        // `disconnect` is a no-op when there is no current connection.
        self.disconnect();
    }
}
240
241pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
242    let mut auth_header = req
243        .headers()
244        .get(http::header::AUTHORIZATION)?
245        .to_str()
246        .ok()?
247        .split_whitespace();
248    let user_id = auth_header.next()?.parse().ok()?;
249    let access_token = auth_header.next()?;
250    Some(Credentials {
251        user_id,
252        access_token: access_token.to_string(),
253    })
254}
255
256pub fn make_get_authenticated_user_response(
257    user_id: i32,
258    github_login: String,
259) -> GetAuthenticatedUserResponse {
260    GetAuthenticatedUserResponse {
261        user: AuthenticatedUser {
262            id: user_id,
263            metrics_id: format!("metrics-id-{user_id}"),
264            avatar_url: "".to_string(),
265            github_login,
266            name: None,
267            is_staff: false,
268            accepted_tos_at: None,
269        },
270        feature_flags: vec![],
271        organizations: vec![],
272        plans_by_organization: BTreeMap::new(),
273        plan: PlanInfo {
274            plan: KnownOrUnknown::Known(Plan::ZedPro),
275            subscription_period: None,
276            usage: CurrentUsage {
277                edit_predictions: UsageData {
278                    used: 250,
279                    limit: UsageLimit::Unlimited,
280                },
281            },
282            trial_started_at: None,
283            is_account_too_young: false,
284            has_overdue_invoices: false,
285        },
286    }
287}