1use super::Client;
2use super::*;
3use crate::http::{HttpClient, Request, Response, ServerResponse};
4use futures::{future::BoxFuture, Future};
5use gpui::{ModelHandle, TestAppContext};
6use parking_lot::Mutex;
7use postage::{mpsc, prelude::Stream};
8use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
9use std::fmt;
10use std::sync::atomic::Ordering::SeqCst;
11use std::sync::{
12 atomic::{AtomicBool, AtomicUsize},
13 Arc,
14};
15
/// A fake RPC server for tests.
///
/// It authenticates a [`Client`] through overridden hooks and talks to it over an
/// in-memory [`Connection`]. Field order is left as-is because it determines drop order.
pub struct FakeServer {
    /// RPC peer that owns the server side of each in-memory connection.
    peer: Arc<Peer>,
    /// Messages arriving from the client on the current connection; `None` while disconnected.
    incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
    /// Id of the current connection on `peer`; `None` while disconnected.
    connection_id: Mutex<Option<ConnectionId>>,
    /// When `true`, new connection attempts are rejected.
    forbid_connections: AtomicBool,
    /// How many times the client's (overridden) authenticate hook has run.
    auth_count: AtomicUsize,
    /// Current access token; clients must present its decimal string form to connect.
    access_token: AtomicUsize,
    /// The only user id a client is allowed to connect as.
    user_id: u64,
}
25
26impl FakeServer {
27 pub async fn for_client(
28 client_user_id: u64,
29 client: &mut Arc<Client>,
30 cx: &TestAppContext,
31 ) -> Arc<Self> {
32 let server = Arc::new(Self {
33 peer: Peer::new(),
34 incoming: Default::default(),
35 connection_id: Default::default(),
36 forbid_connections: Default::default(),
37 auth_count: Default::default(),
38 access_token: Default::default(),
39 user_id: client_user_id,
40 });
41
42 Arc::get_mut(client)
43 .unwrap()
44 .override_authenticate({
45 let server = server.clone();
46 move |cx| {
47 server.auth_count.fetch_add(1, SeqCst);
48 let access_token = server.access_token.load(SeqCst).to_string();
49 cx.spawn(move |_| async move {
50 Ok(Credentials {
51 user_id: client_user_id,
52 access_token,
53 })
54 })
55 }
56 })
57 .override_establish_connection({
58 let server = server.clone();
59 move |credentials, cx| {
60 let credentials = credentials.clone();
61 cx.spawn({
62 let server = server.clone();
63 move |cx| async move { server.establish_connection(&credentials, &cx).await }
64 })
65 }
66 });
67
68 client
69 .authenticate_and_connect(&cx.to_async())
70 .await
71 .unwrap();
72 server
73 }
74
75 pub async fn disconnect(&self) {
76 self.peer.disconnect(self.connection_id()).await;
77 self.connection_id.lock().take();
78 self.incoming.lock().take();
79 }
80
81 async fn establish_connection(
82 &self,
83 credentials: &Credentials,
84 cx: &AsyncAppContext,
85 ) -> Result<Connection, EstablishConnectionError> {
86 assert_eq!(credentials.user_id, self.user_id);
87
88 if self.forbid_connections.load(SeqCst) {
89 Err(EstablishConnectionError::Other(anyhow!(
90 "server is forbidding connections"
91 )))?
92 }
93
94 if credentials.access_token != self.access_token.load(SeqCst).to_string() {
95 Err(EstablishConnectionError::Unauthorized)?
96 }
97
98 let (client_conn, server_conn, _) = Connection::in_memory();
99 let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
100 cx.background().spawn(io).detach();
101 *self.incoming.lock() = Some(incoming);
102 *self.connection_id.lock() = Some(connection_id);
103 Ok(client_conn)
104 }
105
106 pub fn auth_count(&self) -> usize {
107 self.auth_count.load(SeqCst)
108 }
109
110 pub fn roll_access_token(&self) {
111 self.access_token.fetch_add(1, SeqCst);
112 }
113
114 pub fn forbid_connections(&self) {
115 self.forbid_connections.store(true, SeqCst);
116 }
117
118 pub fn allow_connections(&self) {
119 self.forbid_connections.store(false, SeqCst);
120 }
121
122 pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
123 self.peer.send(self.connection_id(), message).await.unwrap();
124 }
125
126 pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
127 let message = self
128 .incoming
129 .lock()
130 .as_mut()
131 .expect("not connected")
132 .recv()
133 .await
134 .ok_or_else(|| anyhow!("other half hung up"))?;
135 let type_name = message.payload_type_name();
136 Ok(*message
137 .into_any()
138 .downcast::<TypedEnvelope<M>>()
139 .unwrap_or_else(|_| {
140 panic!(
141 "fake server received unexpected message type: {:?}",
142 type_name
143 );
144 }))
145 }
146
147 pub async fn respond<T: proto::RequestMessage>(
148 &self,
149 receipt: Receipt<T>,
150 response: T::Response,
151 ) {
152 self.peer.respond(receipt, response).await.unwrap()
153 }
154
155 fn connection_id(&self) -> ConnectionId {
156 self.connection_id.lock().expect("not connected")
157 }
158
159 pub async fn build_user_store(
160 &self,
161 client: Arc<Client>,
162 cx: &mut TestAppContext,
163 ) -> ModelHandle<UserStore> {
164 let http_client = FakeHttpClient::with_404_response();
165 let user_store = cx.add_model(|cx| UserStore::new(client, http_client, cx));
166 assert_eq!(
167 self.receive::<proto::GetUsers>()
168 .await
169 .unwrap()
170 .payload
171 .user_ids,
172 &[self.user_id]
173 );
174 user_store
175 }
176}
177
178pub struct FakeHttpClient {
179 handler:
180 Box<dyn 'static + Send + Sync + Fn(Request) -> BoxFuture<'static, Result<ServerResponse>>>,
181}
182
183impl FakeHttpClient {
184 pub fn new<Fut, F>(handler: F) -> Arc<dyn HttpClient>
185 where
186 Fut: 'static + Send + Future<Output = Result<ServerResponse>>,
187 F: 'static + Send + Sync + Fn(Request) -> Fut,
188 {
189 Arc::new(Self {
190 handler: Box::new(move |req| Box::pin(handler(req))),
191 })
192 }
193
194 pub fn with_404_response() -> Arc<dyn HttpClient> {
195 Self::new(|_| async move { Ok(ServerResponse::new(404)) })
196 }
197}
198
199impl fmt::Debug for FakeHttpClient {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.debug_struct("FakeHttpClient").finish()
202 }
203}
204
205impl HttpClient for FakeHttpClient {
206 fn send<'a>(&'a self, req: Request) -> BoxFuture<'a, Result<Response>> {
207 let future = (self.handler)(req);
208 Box::pin(async move { future.await.map(Into::into) })
209 }
210}