1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 future::Future,
16 pin::Pin,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21};
22
23type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
24
/// Opaque identifier that a `Peer` allocates for each connection added to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u32);
27
/// Per-connection state shared between the `Peer` API methods and the
/// connection's read loop.
///
/// Fields are dropped in declaration order, so `_close_barrier` is released last.
struct Connection {
    /// Writer half used for outgoing messages. The mutex serializes calls to
    /// `write_message`, because each write is awaited while the lock is held.
    writer: Mutex<MessageStream<BoxedWriter>>,
    /// Pending requests, keyed by the message id they were sent with. The read loop
    /// removes an entry and sends the incoming envelope into it when a message
    /// arrives whose `responding_to` equals that id.
    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
    /// Source of message ids for outgoing envelopes on this connection.
    next_message_id: AtomicU32,
    /// Never sent on; held only so that it is dropped together with this struct.
    /// The read loop created in `Peer::add_connection` awaits the paired receiver
    /// and exits once it resolves.
    _close_barrier: barrier::Sender,
}
34
35type MessageHandler =
36 Box<dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<()>>>;
37
38pub struct TypedEnvelope<T> {
39 id: u32,
40 connection_id: ConnectionId,
41 payload: T,
42}
43
44impl<T> TypedEnvelope<T> {
45 pub fn connection_id(&self) -> ConnectionId {
46 self.connection_id
47 }
48
49 pub fn payload(&self) -> &T {
50 &self.payload
51 }
52}
53
/// Manages a set of message-stream connections. It routes incoming responses to the
/// requests awaiting them and dispatches all other incoming messages to typed handlers.
pub struct Peer {
    /// Registered connections by id. `disconnect` removes entries from this map.
    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
    /// Type-erased handlers, tried in registration order for each incoming
    /// message that is not a response.
    message_handlers: RwLock<Vec<MessageHandler>>,
    /// Message types that already have a handler, used to reject a second
    /// registration for the same type.
    handler_types: Mutex<HashSet<TypeId>>,
    /// Counter from which new `ConnectionId`s are allocated.
    next_connection_id: AtomicU32,
}
60
61impl Peer {
62 pub fn new() -> Arc<Self> {
63 Arc::new(Self {
64 connections: Default::default(),
65 message_handlers: Default::default(),
66 handler_types: Default::default(),
67 next_connection_id: Default::default(),
68 })
69 }
70
71 pub async fn add_message_handler<T: EnvelopedMessage>(
72 &self,
73 ) -> mpsc::Receiver<TypedEnvelope<T>> {
74 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
75 panic!("duplicate handler type");
76 }
77
78 let (tx, rx) = mpsc::channel(256);
79 self.message_handlers
80 .write()
81 .await
82 .push(Box::new(move |envelope, connection_id| {
83 if envelope.as_ref().map_or(false, T::matches_envelope) {
84 let envelope = Option::take(envelope).unwrap();
85 let mut tx = tx.clone();
86 Some(
87 async move {
88 tx.send(TypedEnvelope {
89 id: envelope.id,
90 connection_id,
91 payload: T::from_envelope(envelope).unwrap(),
92 })
93 .await;
94 }
95 .boxed(),
96 )
97 } else {
98 None
99 }
100 }));
101 rx
102 }
103
104 pub async fn add_connection<Conn>(
105 self: &Arc<Self>,
106 conn: Conn,
107 ) -> (ConnectionId, impl Future<Output = ()>)
108 where
109 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
110 {
111 let connection_id = ConnectionId(
112 self.next_connection_id
113 .fetch_add(1, atomic::Ordering::SeqCst),
114 );
115 let (close_tx, mut close_rx) = barrier::channel();
116 let connection = Arc::new(Connection {
117 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
118 response_channels: Default::default(),
119 next_message_id: Default::default(),
120 _close_barrier: close_tx,
121 });
122
123 self.connections
124 .write()
125 .await
126 .insert(connection_id, connection.clone());
127
128 let this = self.clone();
129 let handler_future = async move {
130 let closed = close_rx.recv();
131 futures::pin_mut!(closed);
132
133 let mut stream = MessageStream::new(conn);
134 loop {
135 let read_message = stream.read_message();
136 futures::pin_mut!(read_message);
137
138 match futures::future::select(read_message, &mut closed).await {
139 Either::Left((Ok(incoming), _)) => {
140 if let Some(responding_to) = incoming.responding_to {
141 let channel = connection
142 .response_channels
143 .lock()
144 .await
145 .remove(&responding_to);
146 if let Some(mut tx) = channel {
147 tx.send(incoming).await.ok();
148 } else {
149 log::warn!(
150 "received RPC response to unknown request {}",
151 responding_to
152 );
153 }
154 } else {
155 let mut handled = false;
156 let mut envelope = Some(incoming);
157 for handler in this.message_handlers.read().await.iter() {
158 if let Some(future) = handler(&mut envelope, connection_id) {
159 future.await;
160 handled = true;
161 break;
162 }
163 }
164
165 if !handled {
166 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
167 }
168 }
169 }
170 Either::Left((Err(error), _)) => {
171 log::warn!("received invalid RPC message: {}", error);
172 }
173 Either::Right(_) => break,
174 }
175 }
176 };
177
178 (connection_id, handler_future)
179 }
180
181 pub async fn disconnect(&self, connection_id: ConnectionId) {
182 self.connections.write().await.remove(&connection_id);
183 }
184
185 pub fn request<T: RequestMessage>(
186 self: &Arc<Self>,
187 connection_id: ConnectionId,
188 req: T,
189 ) -> impl Future<Output = Result<T::Response>> {
190 let this = self.clone();
191 let (tx, mut rx) = oneshot::channel();
192 async move {
193 let connection = this
194 .connections
195 .read()
196 .await
197 .get(&connection_id)
198 .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
199 .clone();
200 let message_id = connection
201 .next_message_id
202 .fetch_add(1, atomic::Ordering::SeqCst);
203 connection
204 .response_channels
205 .lock()
206 .await
207 .insert(message_id, tx);
208 connection
209 .writer
210 .lock()
211 .await
212 .write_message(&req.into_envelope(message_id, None))
213 .await?;
214 let response = rx
215 .recv()
216 .await
217 .expect("response channel was unexpectedly dropped");
218 T::Response::from_envelope(response)
219 .ok_or_else(|| anyhow!("received response of the wrong type"))
220 }
221 }
222
223 pub fn send<T: EnvelopedMessage>(
224 self: &Arc<Self>,
225 connection_id: ConnectionId,
226 message: T,
227 ) -> impl Future<Output = Result<()>> {
228 let this = self.clone();
229 async move {
230 let connection = this
231 .connections
232 .read()
233 .await
234 .get(&connection_id)
235 .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
236 .clone();
237 let message_id = connection
238 .next_message_id
239 .fetch_add(1, atomic::Ordering::SeqCst);
240 connection
241 .writer
242 .lock()
243 .await
244 .write_message(&message.into_envelope(message_id, None))
245 .await?;
246 Ok(())
247 }
248 }
249
250 pub fn respond<T: RequestMessage>(
251 self: &Arc<Self>,
252 request: TypedEnvelope<T>,
253 response: T::Response,
254 ) -> impl Future<Output = Result<()>> {
255 let this = self.clone();
256 async move {
257 let connection = this
258 .connections
259 .read()
260 .await
261 .get(&request.connection_id)
262 .ok_or_else(|| anyhow!("unknown connection: {}", request.connection_id.0))?
263 .clone();
264 let message_id = connection
265 .next_message_id
266 .fetch_add(1, atomic::Ordering::SeqCst);
267 connection
268 .writer
269 .lock()
270 .await
271 .write_message(&response.into_envelope(message_id, Some(request.id)))
272 .await?;
273 Ok(())
274 }
275 }
276}
277
278// #[cfg(test)]
279// mod tests {
280// use super::*;
281// use smol::{
282// future::poll_once,
283// io::AsyncWriteExt,
284// net::unix::{UnixListener, UnixStream},
285// };
286// use std::{future::Future, io};
287// use tempdir::TempDir;
288
289// #[gpui::test]
290// async fn test_request_response(cx: gpui::TestAppContext) {
291// let executor = cx.read(|app| app.background_executor().clone());
292// let socket_dir_path = TempDir::new("request-response").unwrap();
293// let socket_path = socket_dir_path.path().join(".sock");
294// let listener = UnixListener::bind(&socket_path).unwrap();
295// let client_conn = UnixStream::connect(&socket_path).await.unwrap();
296// let (server_conn, _) = listener.accept().await.unwrap();
297
298// let mut server_stream = MessageStream::new(server_conn);
299// let client = Peer::new();
300// let (connection_id, handler) = client.add_connection(client_conn).await;
301// executor.spawn(handler).detach();
302
303// let client_req = client.request(
304// connection_id,
305// proto::Auth {
306// user_id: 42,
307// access_token: "token".to_string(),
308// },
309// );
310// smol::pin!(client_req);
311// let server_req = send_recv(&mut client_req, server_stream.read_message())
312// .await
313// .unwrap();
314// assert_eq!(
315// server_req.payload,
316// Some(proto::envelope::Payload::Auth(proto::Auth {
317// user_id: 42,
318// access_token: "token".to_string()
319// }))
320// );
321
322// // Respond to another request to ensure requests are properly matched up.
323// server_stream
324// .write_message(
325// &proto::AuthResponse {
326// credentials_valid: false,
327// }
328// .into_envelope(1000, Some(999)),
329// )
330// .await
331// .unwrap();
332// server_stream
333// .write_message(
334// &proto::AuthResponse {
335// credentials_valid: true,
336// }
337// .into_envelope(1001, Some(server_req.id)),
338// )
339// .await
340// .unwrap();
341// assert_eq!(
342// client_req.await.unwrap(),
343// proto::AuthResponse {
344// credentials_valid: true
345// }
346// );
347// }
348
349// #[gpui::test]
350// async fn test_disconnect(cx: gpui::TestAppContext) {
351// let executor = cx.read(|app| app.background_executor().clone());
352// let socket_dir_path = TempDir::new("drop-client").unwrap();
353// let socket_path = socket_dir_path.path().join(".sock");
354// let listener = UnixListener::bind(&socket_path).unwrap();
355// let client_conn = UnixStream::connect(&socket_path).await.unwrap();
356// let (mut server_conn, _) = listener.accept().await.unwrap();
357
358// let client = Peer::new();
359// let (connection_id, handler) = client.add_connection(client_conn).await;
360// executor.spawn(handler).detach();
361// client.disconnect(connection_id).await;
362
363// // Try sending an empty payload over and over, until the client is dropped and hangs up.
364// loop {
365// match server_conn.write(&[]).await {
366// Ok(_) => {}
367// Err(err) => {
368// if err.kind() == io::ErrorKind::BrokenPipe {
369// break;
370// }
371// }
372// }
373// }
374// }
375
376// #[gpui::test]
377// async fn test_io_error(cx: gpui::TestAppContext) {
378// let executor = cx.read(|app| app.background_executor().clone());
379// let socket_dir_path = TempDir::new("io-error").unwrap();
380// let socket_path = socket_dir_path.path().join(".sock");
381// let _listener = UnixListener::bind(&socket_path).unwrap();
382// let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
383// client_conn.close().await.unwrap();
384
385// let client = Peer::new();
386// let (connection_id, handler) = client.add_connection(client_conn).await;
387// executor.spawn(handler).detach();
388// let err = client
389// .request(
390// connection_id,
391// proto::Auth {
392// user_id: 42,
393// access_token: "token".to_string(),
394// },
395// )
396// .await
397// .unwrap_err();
398// assert_eq!(
399// err.downcast_ref::<io::Error>().unwrap().kind(),
400// io::ErrorKind::BrokenPipe
401// );
402// }
403
404// async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
405// where
406// S: Unpin + Future,
407// R: Future<Output = O>,
408// {
409// smol::pin!(receiver);
410// loop {
411// poll_once(&mut sender).await;
412// match poll_once(&mut receiver).await {
413// Some(message) => break message,
414// None => continue,
415// }
416// }
417// }
418// }