test.rs

  1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use cloud_api_client::{
  4    AuthenticatedUser, GetAuthenticatedUserResponse, KnownOrUnknown, Plan, PlanInfo,
  5};
  6use cloud_llm_client::{CurrentUsage, UsageData, UsageLimit};
  7use futures::{StreamExt, stream::BoxStream};
  8use gpui::{AppContext as _, Entity, TestAppContext};
  9use http_client::{AsyncBody, Method, Request, http};
 10use parking_lot::Mutex;
 11use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
 12use std::sync::Arc;
 13
/// A fake RPC server for client tests.
///
/// Constructed with [`FakeServer::for_client`], which installs hooks on the
/// client so that its authentication and connection attempts are served by
/// this fake. Dropping the server disconnects the client.
pub struct FakeServer {
    /// Server-side RPC peer; the client's in-memory connection is registered on it.
    peer: Arc<Peer>,
    /// Mutable state shared with the hooks installed on the client.
    state: Arc<Mutex<FakeServerState>>,
    /// The user id the client is expected to authenticate as
    /// (the `client_user_id` passed to `for_client`).
    user_id: u64,
}
 19
/// Mutable state of a [`FakeServer`], shared with the hooks it installs on the client.
#[derive(Default)]
struct FakeServerState {
    /// Messages sent by the client on the current connection; `None` while disconnected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Id of the current client connection; `None` while disconnected.
    connection_id: Option<ConnectionId>,
    /// When `true`, the establish-connection hook rejects every attempt.
    forbid_connections: bool,
    /// Number of times the authenticate hook has been invoked.
    auth_count: usize,
    /// Counter used as the access token (compared via its decimal string form);
    /// bumped by `roll_access_token` to invalidate previously issued credentials.
    access_token: usize,
}
 28
 29impl FakeServer {
 30    pub async fn for_client(
 31        client_user_id: u64,
 32        client: &Arc<Client>,
 33        cx: &TestAppContext,
 34    ) -> Self {
 35        let server = Self {
 36            peer: Peer::new(0),
 37            state: Default::default(),
 38            user_id: client_user_id,
 39        };
 40
 41        client.http_client().as_fake().replace_handler({
 42            let state = server.state.clone();
 43            move |old_handler, req| {
 44                let state = state.clone();
 45                let old_handler = old_handler.clone();
 46                async move {
 47                    match (req.method(), req.uri().path()) {
 48                        (&Method::GET, "/client/users/me") => {
 49                            let credentials = parse_authorization_header(&req);
 50                            if credentials
 51                                != Some(Credentials {
 52                                    user_id: client_user_id,
 53                                    access_token: state.lock().access_token.to_string(),
 54                                })
 55                            {
 56                                return Ok(http_client::Response::builder()
 57                                    .status(401)
 58                                    .body("Unauthorized".into())
 59                                    .unwrap());
 60                            }
 61
 62                            Ok(http_client::Response::builder()
 63                                .status(200)
 64                                .body(
 65                                    serde_json::to_string(&make_get_authenticated_user_response(
 66                                        client_user_id as i32,
 67                                        format!("user-{client_user_id}"),
 68                                    ))
 69                                    .unwrap()
 70                                    .into(),
 71                                )
 72                                .unwrap())
 73                        }
 74                        _ => old_handler(req).await,
 75                    }
 76                }
 77            }
 78        });
 79        client
 80            .override_authenticate({
 81                let state = Arc::downgrade(&server.state);
 82                move |cx| {
 83                    let state = state.clone();
 84                    cx.spawn(async move |_| {
 85                        let state = state.upgrade().context("server dropped")?;
 86                        let mut state = state.lock();
 87                        state.auth_count += 1;
 88                        let access_token = state.access_token.to_string();
 89                        Ok(Credentials {
 90                            user_id: client_user_id,
 91                            access_token,
 92                        })
 93                    })
 94                }
 95            })
 96            .override_establish_connection({
 97                let peer = Arc::downgrade(&server.peer);
 98                let state = Arc::downgrade(&server.state);
 99                move |credentials, cx| {
100                    let peer = peer.clone();
101                    let state = state.clone();
102                    let credentials = credentials.clone();
103                    cx.spawn(async move |cx| {
104                        let state = state.upgrade().context("server dropped")?;
105                        let peer = peer.upgrade().context("server dropped")?;
106                        if state.lock().forbid_connections {
107                            Err(EstablishConnectionError::Other(anyhow!(
108                                "server is forbidding connections"
109                            )))?
110                        }
111
112                        if credentials
113                            != (Credentials {
114                                user_id: client_user_id,
115                                access_token: state.lock().access_token.to_string(),
116                            })
117                        {
118                            Err(EstablishConnectionError::Unauthorized)?
119                        }
120
121                        let (client_conn, server_conn, _) =
122                            Connection::in_memory(cx.background_executor().clone());
123                        let (connection_id, io, incoming) =
124                            peer.add_test_connection(server_conn, cx.background_executor().clone());
125                        cx.background_spawn(io).detach();
126                        {
127                            let mut state = state.lock();
128                            state.connection_id = Some(connection_id);
129                            state.incoming = Some(incoming);
130                        }
131                        peer.send(
132                            connection_id,
133                            proto::Hello {
134                                peer_id: Some(connection_id.into()),
135                            },
136                        )
137                        .unwrap();
138
139                        Ok(client_conn)
140                    })
141                }
142            });
143
144        client
145            .connect(false, &cx.to_async())
146            .await
147            .into_response()
148            .unwrap();
149
150        server
151    }
152
153    pub fn disconnect(&self) {
154        if self.state.lock().connection_id.is_some() {
155            self.peer.disconnect(self.connection_id());
156            let mut state = self.state.lock();
157            state.connection_id.take();
158            state.incoming.take();
159        }
160    }
161
162    pub fn auth_count(&self) -> usize {
163        self.state.lock().auth_count
164    }
165
166    pub fn roll_access_token(&self) {
167        self.state.lock().access_token += 1;
168    }
169
170    pub fn forbid_connections(&self) {
171        self.state.lock().forbid_connections = true;
172    }
173
174    pub fn allow_connections(&self) {
175        self.state.lock().forbid_connections = false;
176    }
177
178    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
179        self.peer.send(self.connection_id(), message).unwrap();
180    }
181
182    #[allow(clippy::await_holding_lock)]
183    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
184        let message = self
185            .state
186            .lock()
187            .incoming
188            .as_mut()
189            .expect("not connected")
190            .next()
191            .await
192            .context("other half hung up")?;
193        let type_name = message.payload_type_name();
194        let message = message.into_any();
195
196        if message.is::<TypedEnvelope<M>>() {
197            return Ok(*message.downcast().unwrap());
198        }
199
200        panic!(
201            "fake server received unexpected message type: {:?}",
202            type_name
203        );
204    }
205
206    pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
207        self.peer.respond(receipt, response).unwrap()
208    }
209
210    fn connection_id(&self) -> ConnectionId {
211        self.state.lock().connection_id.expect("not connected")
212    }
213
214    pub async fn build_user_store(
215        &self,
216        client: Arc<Client>,
217        cx: &mut TestAppContext,
218    ) -> Entity<UserStore> {
219        let user_store = cx.new(|cx| UserStore::new(client, cx));
220        assert_eq!(
221            self.receive::<proto::GetUsers>()
222                .await
223                .unwrap()
224                .payload
225                .user_ids,
226            &[self.user_id]
227        );
228        user_store
229    }
230}
231
232impl Drop for FakeServer {
233    fn drop(&mut self) {
234        self.disconnect();
235    }
236}
237
238pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
239    let mut auth_header = req
240        .headers()
241        .get(http::header::AUTHORIZATION)?
242        .to_str()
243        .ok()?
244        .split_whitespace();
245    let user_id = auth_header.next()?.parse().ok()?;
246    let access_token = auth_header.next()?;
247    Some(Credentials {
248        user_id,
249        access_token: access_token.to_string(),
250    })
251}
252
253pub fn make_get_authenticated_user_response(
254    user_id: i32,
255    github_login: String,
256) -> GetAuthenticatedUserResponse {
257    GetAuthenticatedUserResponse {
258        user: AuthenticatedUser {
259            id: user_id,
260            metrics_id: format!("metrics-id-{user_id}"),
261            avatar_url: "".to_string(),
262            github_login,
263            name: None,
264            is_staff: false,
265            accepted_tos_at: None,
266        },
267        feature_flags: vec![],
268        plan: PlanInfo {
269            plan: KnownOrUnknown::Known(Plan::ZedPro),
270            subscription_period: None,
271            usage: CurrentUsage {
272                edit_predictions: UsageData {
273                    used: 250,
274                    limit: UsageLimit::Unlimited,
275                },
276            },
277            trial_started_at: None,
278            is_account_too_young: false,
279            has_overdue_invoices: false,
280        },
281    }
282}