1use super::Client;
2use super::*;
3use crate::http::{HttpClient, Request, Response, ServerResponse};
4use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
5use gpui::{ModelHandle, TestAppContext};
6use parking_lot::Mutex;
7use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
8use std::fmt;
9use std::sync::atomic::Ordering::SeqCst;
10use std::sync::{
11 atomic::{AtomicBool, AtomicUsize},
12 Arc,
13};
14
/// An in-process stand-in for the RPC server, used to drive a `Client` in tests.
pub struct FakeServer {
    /// RPC peer through which the fake server adds connections, sends messages,
    /// responds to requests, and disconnects.
    peer: Arc<Peer>,
    /// Stream of messages arriving on the current connection; `None` until a
    /// connection is established and again after `disconnect`.
    incoming: Mutex<Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>>,
    /// Id of the current connection; `None` while not connected.
    connection_id: Mutex<Option<ConnectionId>>,
    /// While `true`, `establish_connection` rejects every attempt.
    forbid_connections: AtomicBool,
    /// Number of times the overridden authenticate hook has been invoked.
    auth_count: AtomicUsize,
    /// Counter whose decimal string form is the access token that credentials
    /// must carry; incremented by `roll_access_token`.
    access_token: AtomicUsize,
    /// User id that connecting credentials are asserted to match.
    user_id: u64,
}
24
25impl FakeServer {
26 pub async fn for_client(
27 client_user_id: u64,
28 client: &mut Arc<Client>,
29 cx: &TestAppContext,
30 ) -> Arc<Self> {
31 let server = Arc::new(Self {
32 peer: Peer::new(),
33 incoming: Default::default(),
34 connection_id: Default::default(),
35 forbid_connections: Default::default(),
36 auth_count: Default::default(),
37 access_token: Default::default(),
38 user_id: client_user_id,
39 });
40
41 Arc::get_mut(client)
42 .unwrap()
43 .override_authenticate({
44 let server = server.clone();
45 move |cx| {
46 server.auth_count.fetch_add(1, SeqCst);
47 let access_token = server.access_token.load(SeqCst).to_string();
48 cx.spawn(move |_| async move {
49 Ok(Credentials {
50 user_id: client_user_id,
51 access_token,
52 })
53 })
54 }
55 })
56 .override_establish_connection({
57 let server = server.clone();
58 move |credentials, cx| {
59 let credentials = credentials.clone();
60 cx.spawn({
61 let server = server.clone();
62 move |cx| async move { server.establish_connection(&credentials, &cx).await }
63 })
64 }
65 });
66
67 client
68 .authenticate_and_connect(&cx.to_async())
69 .await
70 .unwrap();
71 server
72 }
73
74 pub fn disconnect(&self) {
75 self.peer.disconnect(self.connection_id());
76 self.connection_id.lock().take();
77 self.incoming.lock().take();
78 }
79
80 async fn establish_connection(
81 &self,
82 credentials: &Credentials,
83 cx: &AsyncAppContext,
84 ) -> Result<Connection, EstablishConnectionError> {
85 assert_eq!(credentials.user_id, self.user_id);
86
87 if self.forbid_connections.load(SeqCst) {
88 Err(EstablishConnectionError::Other(anyhow!(
89 "server is forbidding connections"
90 )))?
91 }
92
93 if credentials.access_token != self.access_token.load(SeqCst).to_string() {
94 Err(EstablishConnectionError::Unauthorized)?
95 }
96
97 let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
98 let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
99 cx.background().spawn(io).detach();
100 *self.incoming.lock() = Some(incoming);
101 *self.connection_id.lock() = Some(connection_id);
102 Ok(client_conn)
103 }
104
105 pub fn auth_count(&self) -> usize {
106 self.auth_count.load(SeqCst)
107 }
108
109 pub fn roll_access_token(&self) {
110 self.access_token.fetch_add(1, SeqCst);
111 }
112
113 pub fn forbid_connections(&self) {
114 self.forbid_connections.store(true, SeqCst);
115 }
116
117 pub fn allow_connections(&self) {
118 self.forbid_connections.store(false, SeqCst);
119 }
120
121 pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
122 self.peer.send(self.connection_id(), message).unwrap();
123 }
124
125 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
126 let message = self
127 .incoming
128 .lock()
129 .as_mut()
130 .expect("not connected")
131 .next()
132 .await
133 .ok_or_else(|| anyhow!("other half hung up"))?;
134 let type_name = message.payload_type_name();
135 Ok(*message
136 .into_any()
137 .downcast::<TypedEnvelope<M>>()
138 .unwrap_or_else(|_| {
139 panic!(
140 "fake server received unexpected message type: {:?}",
141 type_name
142 );
143 }))
144 }
145
146 pub async fn respond<T: proto::RequestMessage>(
147 &self,
148 receipt: Receipt<T>,
149 response: T::Response,
150 ) {
151 self.peer.respond(receipt, response).unwrap()
152 }
153
154 fn connection_id(&self) -> ConnectionId {
155 self.connection_id.lock().expect("not connected")
156 }
157
158 pub async fn build_user_store(
159 &self,
160 client: Arc<Client>,
161 cx: &mut TestAppContext,
162 ) -> ModelHandle<UserStore> {
163 let http_client = FakeHttpClient::with_404_response();
164 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
165 assert_eq!(
166 self.receive::<proto::GetUsers>()
167 .await
168 .unwrap()
169 .payload
170 .user_ids,
171 &[self.user_id]
172 );
173 user_store
174 }
175}
176
177pub struct FakeHttpClient {
178 handler:
179 Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
180}
181
182impl FakeHttpClient {
183 pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
184 where
185 Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
186 F: 'static + Send + Sync + Fn(Request) -> Fut,
187 {
188 Arc::new(Self {
189 handler: Box::new(move |req| Box::pin(handler(req))),
190 })
191 }
192
193 pub fn with_404_response() -> Arc<dyn HttpClient> {
194 Self::new(|_| async move { Ok(ServerResponse::new(404)) })
195 }
196}
197
198impl fmt::Debug for FakeHttpClient {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 f.debug_struct("FakeHttpClient").finish()
201 }
202}
203
204impl HttpClient for FakeHttpClient {
205 fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
206 let future = (self.handler)(req);
207 Box::pin(async move { future.await.map(Into::into) })
208 }
209}