test.rs

  1use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
  2use anyhow::{Context as _, Result, anyhow};
  3use cloud_api_client::{AuthenticatedUser, GetAuthenticatedUserResponse, PlanInfo};
  4use cloud_llm_client::{CurrentUsage, Plan, UsageData, UsageLimit};
  5use futures::{StreamExt, stream::BoxStream};
  6use gpui::{AppContext as _, Entity, TestAppContext};
  7use http_client::{AsyncBody, Method, Request, http};
  8use parking_lot::Mutex;
  9use rpc::{ConnectionId, Peer, Receipt, TypedEnvelope, proto};
 10use std::sync::Arc;
 11
 12pub struct FakeServer {
 13    peer: Arc<Peer>,
 14    state: Arc<Mutex<FakeServerState>>,
 15    user_id: u64,
 16}
 17
 18#[derive(Default)]
 19struct FakeServerState {
 20    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
 21    connection_id: Option<ConnectionId>,
 22    forbid_connections: bool,
 23    auth_count: usize,
 24    access_token: usize,
 25}
 26
 27impl FakeServer {
 28    pub async fn for_client(
 29        client_user_id: u64,
 30        client: &Arc<Client>,
 31        cx: &TestAppContext,
 32    ) -> Self {
 33        let server = Self {
 34            peer: Peer::new(0),
 35            state: Default::default(),
 36            user_id: client_user_id,
 37        };
 38
 39        client.http_client().as_fake().replace_handler({
 40            let state = server.state.clone();
 41            move |old_handler, req| {
 42                let state = state.clone();
 43                let old_handler = old_handler.clone();
 44                async move {
 45                    match (req.method(), req.uri().path()) {
 46                        (&Method::GET, "/client/users/me") => {
 47                            let credentials = parse_authorization_header(&req);
 48                            if credentials
 49                                != Some(Credentials {
 50                                    user_id: client_user_id,
 51                                    access_token: state.lock().access_token.to_string(),
 52                                })
 53                            {
 54                                return Ok(http_client::Response::builder()
 55                                    .status(401)
 56                                    .body("Unauthorized".into())
 57                                    .unwrap());
 58                            }
 59
 60                            Ok(http_client::Response::builder()
 61                                .status(200)
 62                                .body(
 63                                    serde_json::to_string(&make_get_authenticated_user_response(
 64                                        client_user_id as i32,
 65                                        format!("user-{client_user_id}"),
 66                                    ))
 67                                    .unwrap()
 68                                    .into(),
 69                                )
 70                                .unwrap())
 71                        }
 72                        _ => old_handler(req).await,
 73                    }
 74                }
 75            }
 76        });
 77        client
 78            .override_authenticate({
 79                let state = Arc::downgrade(&server.state);
 80                move |cx| {
 81                    let state = state.clone();
 82                    cx.spawn(async move |_| {
 83                        let state = state.upgrade().context("server dropped")?;
 84                        let mut state = state.lock();
 85                        state.auth_count += 1;
 86                        let access_token = state.access_token.to_string();
 87                        Ok(Credentials {
 88                            user_id: client_user_id,
 89                            access_token,
 90                        })
 91                    })
 92                }
 93            })
 94            .override_establish_connection({
 95                let peer = Arc::downgrade(&server.peer);
 96                let state = Arc::downgrade(&server.state);
 97                move |credentials, cx| {
 98                    let peer = peer.clone();
 99                    let state = state.clone();
100                    let credentials = credentials.clone();
101                    cx.spawn(async move |cx| {
102                        let state = state.upgrade().context("server dropped")?;
103                        let peer = peer.upgrade().context("server dropped")?;
104                        if state.lock().forbid_connections {
105                            Err(EstablishConnectionError::Other(anyhow!(
106                                "server is forbidding connections"
107                            )))?
108                        }
109
110                        if credentials
111                            != (Credentials {
112                                user_id: client_user_id,
113                                access_token: state.lock().access_token.to_string(),
114                            })
115                        {
116                            Err(EstablishConnectionError::Unauthorized)?
117                        }
118
119                        let (client_conn, server_conn, _) =
120                            Connection::in_memory(cx.background_executor().clone());
121                        let (connection_id, io, incoming) =
122                            peer.add_test_connection(server_conn, cx.background_executor().clone());
123                        cx.background_spawn(io).detach();
124                        {
125                            let mut state = state.lock();
126                            state.connection_id = Some(connection_id);
127                            state.incoming = Some(incoming);
128                        }
129                        peer.send(
130                            connection_id,
131                            proto::Hello {
132                                peer_id: Some(connection_id.into()),
133                            },
134                        )
135                        .unwrap();
136
137                        Ok(client_conn)
138                    })
139                }
140            });
141
142        client
143            .connect(false, &cx.to_async())
144            .await
145            .into_response()
146            .unwrap();
147
148        server
149    }
150
151    pub fn disconnect(&self) {
152        if self.state.lock().connection_id.is_some() {
153            self.peer.disconnect(self.connection_id());
154            let mut state = self.state.lock();
155            state.connection_id.take();
156            state.incoming.take();
157        }
158    }
159
160    pub fn auth_count(&self) -> usize {
161        self.state.lock().auth_count
162    }
163
164    pub fn roll_access_token(&self) {
165        self.state.lock().access_token += 1;
166    }
167
168    pub fn forbid_connections(&self) {
169        self.state.lock().forbid_connections = true;
170    }
171
172    pub fn allow_connections(&self) {
173        self.state.lock().forbid_connections = false;
174    }
175
176    pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
177        self.peer.send(self.connection_id(), message).unwrap();
178    }
179
180    #[allow(clippy::await_holding_lock)]
181    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
182        let message = self
183            .state
184            .lock()
185            .incoming
186            .as_mut()
187            .expect("not connected")
188            .next()
189            .await
190            .context("other half hung up")?;
191        let type_name = message.payload_type_name();
192        let message = message.into_any();
193
194        if message.is::<TypedEnvelope<M>>() {
195            return Ok(*message.downcast().unwrap());
196        }
197
198        panic!(
199            "fake server received unexpected message type: {:?}",
200            type_name
201        );
202    }
203
204    pub fn respond<T: proto::RequestMessage>(&self, receipt: Receipt<T>, response: T::Response) {
205        self.peer.respond(receipt, response).unwrap()
206    }
207
208    fn connection_id(&self) -> ConnectionId {
209        self.state.lock().connection_id.expect("not connected")
210    }
211
212    pub async fn build_user_store(
213        &self,
214        client: Arc<Client>,
215        cx: &mut TestAppContext,
216    ) -> Entity<UserStore> {
217        let user_store = cx.new(|cx| UserStore::new(client, cx));
218        assert_eq!(
219            self.receive::<proto::GetUsers>()
220                .await
221                .unwrap()
222                .payload
223                .user_ids,
224            &[self.user_id]
225        );
226        user_store
227    }
228}
229
230impl Drop for FakeServer {
231    fn drop(&mut self) {
232        self.disconnect();
233    }
234}
235
236pub fn parse_authorization_header(req: &Request<AsyncBody>) -> Option<Credentials> {
237    let mut auth_header = req
238        .headers()
239        .get(http::header::AUTHORIZATION)?
240        .to_str()
241        .ok()?
242        .split_whitespace();
243    let user_id = auth_header.next()?.parse().ok()?;
244    let access_token = auth_header.next()?;
245    Some(Credentials {
246        user_id,
247        access_token: access_token.to_string(),
248    })
249}
250
251pub fn make_get_authenticated_user_response(
252    user_id: i32,
253    github_login: String,
254) -> GetAuthenticatedUserResponse {
255    GetAuthenticatedUserResponse {
256        user: AuthenticatedUser {
257            id: user_id,
258            metrics_id: format!("metrics-id-{user_id}"),
259            avatar_url: "".to_string(),
260            github_login,
261            name: None,
262            is_staff: false,
263            accepted_tos_at: None,
264        },
265        feature_flags: vec![],
266        plan: PlanInfo {
267            plan: Plan::ZedPro,
268            subscription_period: None,
269            usage: CurrentUsage {
270                model_requests: UsageData {
271                    used: 0,
272                    limit: UsageLimit::Limited(500),
273                },
274                edit_predictions: UsageData {
275                    used: 250,
276                    limit: UsageLimit::Unlimited,
277                },
278            },
279            trial_started_at: None,
280            is_usage_based_billing_enabled: false,
281            is_account_too_young: false,
282            has_overdue_invoices: false,
283        },
284    }
285}