1use crate::{
2 http::{HttpClient, Request, Response, ServerResponse},
3 Client, Connection, Credentials, EstablishConnectionError, UserStore,
4};
5use anyhow::{anyhow, Result};
6use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
7use gpui::{executor, ModelHandle, TestAppContext};
8use parking_lot::Mutex;
9use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
10use std::{fmt, rc::Rc, sync::Arc};
11
/// Fake server that a test [`Client`] connects to over an in-memory connection.
///
/// Construct it with [`FakeServer::for_client`], which installs authentication and
/// connection overrides on the client that route it to this server.
pub struct FakeServer {
    /// RPC peer used for the server side of the in-memory connection.
    peer: Arc<Peer>,
    /// Connection/auth state, shared with the override closures installed on the client.
    state: Arc<Mutex<FakeServerState>>,
    /// User id the client authenticates as; `build_user_store` expects it in `GetUsers`.
    user_id: u64,
    /// Test foreground executor; `receive` brackets its wait with
    /// `start_waiting`/`finish_waiting` on it.
    executor: Rc<executor::Foreground>,
}
18
/// Mutable state of a [`FakeServer`], shared between the server handle and the
/// authentication/connection overrides it installs on the client.
#[derive(Default)]
struct FakeServerState {
    /// Messages received from the currently connected client; `None` when disconnected.
    incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
    /// Peer connection id of the current client connection; `None` when disconnected.
    connection_id: Option<ConnectionId>,
    /// When `true`, connection attempts fail with `EstablishConnectionError::Other`.
    forbid_connections: bool,
    /// Number of times the client has run the authenticate override.
    auth_count: usize,
    /// Counter whose decimal string form is the currently valid access token (starts at 0).
    /// Only credentials carrying the current value may connect; `roll_access_token`
    /// increments it.
    access_token: usize,
}
27
28impl FakeServer {
29 pub async fn for_client(
30 client_user_id: u64,
31 client: &mut Arc<Client>,
32 cx: &TestAppContext,
33 ) -> Self {
34 let server = Self {
35 peer: Peer::new(),
36 state: Default::default(),
37 user_id: client_user_id,
38 executor: cx.foreground(),
39 };
40
41 Arc::get_mut(client)
42 .unwrap()
43 .override_authenticate({
44 let state = server.state.clone();
45 move |cx| {
46 let mut state = state.lock();
47 state.auth_count += 1;
48 let access_token = state.access_token.to_string();
49 cx.spawn(move |_| async move {
50 Ok(Credentials {
51 user_id: client_user_id,
52 access_token,
53 })
54 })
55 }
56 })
57 .override_establish_connection({
58 let peer = server.peer.clone();
59 let state = server.state.clone();
60 move |credentials, cx| {
61 let peer = peer.clone();
62 let state = state.clone();
63 let credentials = credentials.clone();
64 cx.spawn(move |cx| async move {
65 assert_eq!(credentials.user_id, client_user_id);
66
67 if state.lock().forbid_connections {
68 Err(EstablishConnectionError::Other(anyhow!(
69 "server is forbidding connections"
70 )))?
71 }
72
73 if credentials.access_token != state.lock().access_token.to_string() {
74 Err(EstablishConnectionError::Unauthorized)?
75 }
76
77 let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
78 let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
79 cx.background().spawn(io).detach();
80 let mut state = state.lock();
81 state.connection_id = Some(connection_id);
82 state.incoming = Some(incoming);
83 Ok(client_conn)
84 })
85 }
86 });
87
88 client
89 .authenticate_and_connect(&cx.to_async())
90 .await
91 .unwrap();
92 server
93 }
94
95 pub fn disconnect(&self) {
96 self.peer.disconnect(self.connection_id());
97 let mut state = self.state.lock();
98 state.connection_id.take();
99 state.incoming.take();
100 }
101
102 pub fn auth_count(&self) -> usize {
103 self.state.lock().auth_count
104 }
105
106 pub fn roll_access_token(&self) {
107 self.state.lock().access_token += 1;
108 }
109
110 pub fn forbid_connections(&self) {
111 self.state.lock().forbid_connections = true;
112 }
113
114 pub fn allow_connections(&self) {
115 self.state.lock().forbid_connections = false;
116 }
117
118 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
119 self.peer.send(self.connection_id(), message).unwrap();
120 }
121
122 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
123 self.executor.start_waiting();
124 let message = self
125 .state
126 .lock()
127 .incoming
128 .as_mut()
129 .expect("not connected")
130 .next()
131 .await
132 .ok_or_else(|| anyhow!("other half hung up"))?;
133 self.executor.finish_waiting();
134 let type_name = message.payload_type_name();
135 Ok(*message
136 .into_any()
137 .downcast::<TypedEnvelope<M>>()
138 .unwrap_or_else(|_| {
139 panic!(
140 "fake server received unexpected message type: {:?}",
141 type_name
142 );
143 }))
144 }
145
146 pub async fn respond<T: proto::RequestMessage>(
147 &self,
148 receipt: Receipt<T>,
149 response: T::Response,
150 ) {
151 self.peer.respond(receipt, response).unwrap()
152 }
153
154 fn connection_id(&self) -> ConnectionId {
155 self.state.lock().connection_id.expect("not connected")
156 }
157
158 pub async fn build_user_store(
159 &self,
160 client: Arc<Client>,
161 cx: &mut TestAppContext,
162 ) -> ModelHandle<UserStore> {
163 let http_client = FakeHttpClient::with_404_response();
164 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
165 assert_eq!(
166 self.receive::<proto::GetUsers>()
167 .await
168 .unwrap()
169 .payload
170 .user_ids,
171 &[self.user_id]
172 );
173 user_store
174 }
175}
176
/// Test double for [`HttpClient`] that answers every request by calling a stored handler.
pub struct FakeHttpClient {
    /// Called once per request. The boxed future it returns resolves to the response
    /// (or error) that `send` reports.
    handler:
        Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
}
181
182impl FakeHttpClient {
183 pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
184 where
185 Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
186 F: 'static + Send + Sync + Fn(Request) -> Fut,
187 {
188 Arc::new(Self {
189 handler: Box::new(move |req| Box::pin(handler(req))),
190 })
191 }
192
193 pub fn with_404_response() -> Arc<dyn HttpClient> {
194 Self::new(|_| async move { Ok(ServerResponse::new(404)) })
195 }
196}
197
198impl fmt::Debug for FakeHttpClient {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 f.debug_struct("FakeHttpClient").finish()
201 }
202}
203
204impl HttpClient for FakeHttpClient {
205 fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
206 let future = (self.handler)(req);
207 Box::pin(async move { future.await.map(Into::into) })
208 }
209}