peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    future::Future,
 16    pin::Pin,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21};
 22
 23type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(u32);
 27
 28struct Connection {
 29    writer: Mutex<MessageStream<BoxedWriter>>,
 30    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
 31    next_message_id: AtomicU32,
 32    _close_barrier: barrier::Sender,
 33}
 34
 35type MessageHandler =
 36    Box<dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<()>>>;
 37
 38pub struct TypedEnvelope<T> {
 39    id: u32,
 40    connection_id: ConnectionId,
 41    payload: T,
 42}
 43
 44impl<T> TypedEnvelope<T> {
 45    pub fn connection_id(&self) -> ConnectionId {
 46        self.connection_id
 47    }
 48
 49    pub fn payload(&self) -> &T {
 50        &self.payload
 51    }
 52}
 53
 54pub struct Peer {
 55    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
 56    message_handlers: RwLock<Vec<MessageHandler>>,
 57    handler_types: Mutex<HashSet<TypeId>>,
 58    next_connection_id: AtomicU32,
 59}
 60
 61impl Peer {
 62    pub fn new() -> Arc<Self> {
 63        Arc::new(Self {
 64            connections: Default::default(),
 65            message_handlers: Default::default(),
 66            handler_types: Default::default(),
 67            next_connection_id: Default::default(),
 68        })
 69    }
 70
 71    pub async fn add_message_handler<T: EnvelopedMessage>(
 72        &self,
 73    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 74        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 75            panic!("duplicate handler type");
 76        }
 77
 78        let (tx, rx) = mpsc::channel(256);
 79        self.message_handlers
 80            .write()
 81            .await
 82            .push(Box::new(move |envelope, connection_id| {
 83                if envelope.as_ref().map_or(false, T::matches_envelope) {
 84                    let envelope = Option::take(envelope).unwrap();
 85                    let mut tx = tx.clone();
 86                    Some(
 87                        async move {
 88                            tx.send(TypedEnvelope {
 89                                id: envelope.id,
 90                                connection_id,
 91                                payload: T::from_envelope(envelope).unwrap(),
 92                            })
 93                            .await;
 94                        }
 95                        .boxed(),
 96                    )
 97                } else {
 98                    None
 99                }
100            }));
101        rx
102    }
103
104    pub async fn add_connection<Conn>(
105        self: &Arc<Self>,
106        conn: Conn,
107    ) -> (ConnectionId, impl Future<Output = ()>)
108    where
109        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
110    {
111        let connection_id = ConnectionId(
112            self.next_connection_id
113                .fetch_add(1, atomic::Ordering::SeqCst),
114        );
115        let (close_tx, mut close_rx) = barrier::channel();
116        let connection = Arc::new(Connection {
117            writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
118            response_channels: Default::default(),
119            next_message_id: Default::default(),
120            _close_barrier: close_tx,
121        });
122
123        self.connections
124            .write()
125            .await
126            .insert(connection_id, connection.clone());
127
128        let this = self.clone();
129        let handler_future = async move {
130            let closed = close_rx.recv();
131            futures::pin_mut!(closed);
132
133            let mut stream = MessageStream::new(conn);
134            loop {
135                let read_message = stream.read_message();
136                futures::pin_mut!(read_message);
137
138                match futures::future::select(read_message, &mut closed).await {
139                    Either::Left((Ok(incoming), _)) => {
140                        if let Some(responding_to) = incoming.responding_to {
141                            let channel = connection
142                                .response_channels
143                                .lock()
144                                .await
145                                .remove(&responding_to);
146                            if let Some(mut tx) = channel {
147                                tx.send(incoming).await.ok();
148                            } else {
149                                log::warn!(
150                                    "received RPC response to unknown request {}",
151                                    responding_to
152                                );
153                            }
154                        } else {
155                            let mut handled = false;
156                            let mut envelope = Some(incoming);
157                            for handler in this.message_handlers.read().await.iter() {
158                                if let Some(future) = handler(&mut envelope, connection_id) {
159                                    future.await;
160                                    handled = true;
161                                    break;
162                                }
163                            }
164
165                            if !handled {
166                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
167                            }
168                        }
169                    }
170                    Either::Left((Err(error), _)) => {
171                        log::warn!("received invalid RPC message: {}", error);
172                    }
173                    Either::Right(_) => break,
174                }
175            }
176        };
177
178        (connection_id, handler_future)
179    }
180
181    pub async fn disconnect(&self, connection_id: ConnectionId) {
182        self.connections.write().await.remove(&connection_id);
183    }
184
185    pub fn request<T: RequestMessage>(
186        self: &Arc<Self>,
187        connection_id: ConnectionId,
188        req: T,
189    ) -> impl Future<Output = Result<T::Response>> {
190        let this = self.clone();
191        let (tx, mut rx) = oneshot::channel();
192        async move {
193            let connection = this
194                .connections
195                .read()
196                .await
197                .get(&connection_id)
198                .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
199                .clone();
200            let message_id = connection
201                .next_message_id
202                .fetch_add(1, atomic::Ordering::SeqCst);
203            connection
204                .response_channels
205                .lock()
206                .await
207                .insert(message_id, tx);
208            connection
209                .writer
210                .lock()
211                .await
212                .write_message(&req.into_envelope(message_id, None))
213                .await?;
214            let response = rx
215                .recv()
216                .await
217                .expect("response channel was unexpectedly dropped");
218            T::Response::from_envelope(response)
219                .ok_or_else(|| anyhow!("received response of the wrong type"))
220        }
221    }
222
223    pub fn send<T: EnvelopedMessage>(
224        self: &Arc<Self>,
225        connection_id: ConnectionId,
226        message: T,
227    ) -> impl Future<Output = Result<()>> {
228        let this = self.clone();
229        async move {
230            let connection = this
231                .connections
232                .read()
233                .await
234                .get(&connection_id)
235                .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
236                .clone();
237            let message_id = connection
238                .next_message_id
239                .fetch_add(1, atomic::Ordering::SeqCst);
240            connection
241                .writer
242                .lock()
243                .await
244                .write_message(&message.into_envelope(message_id, None))
245                .await?;
246            Ok(())
247        }
248    }
249
250    pub fn respond<T: RequestMessage>(
251        self: &Arc<Self>,
252        request: TypedEnvelope<T>,
253        response: T::Response,
254    ) -> impl Future<Output = Result<()>> {
255        let this = self.clone();
256        async move {
257            let connection = this
258                .connections
259                .read()
260                .await
261                .get(&request.connection_id)
262                .ok_or_else(|| anyhow!("unknown connection: {}", request.connection_id.0))?
263                .clone();
264            let message_id = connection
265                .next_message_id
266                .fetch_add(1, atomic::Ordering::SeqCst);
267            connection
268                .writer
269                .lock()
270                .await
271                .write_message(&response.into_envelope(message_id, Some(request.id)))
272                .await?;
273            Ok(())
274        }
275    }
276}
277
278// #[cfg(test)]
279// mod tests {
280//     use super::*;
281//     use smol::{
282//         future::poll_once,
283//         io::AsyncWriteExt,
284//         net::unix::{UnixListener, UnixStream},
285//     };
286//     use std::{future::Future, io};
287//     use tempdir::TempDir;
288
289//     #[gpui::test]
290//     async fn test_request_response(cx: gpui::TestAppContext) {
291//         let executor = cx.read(|app| app.background_executor().clone());
292//         let socket_dir_path = TempDir::new("request-response").unwrap();
293//         let socket_path = socket_dir_path.path().join(".sock");
294//         let listener = UnixListener::bind(&socket_path).unwrap();
295//         let client_conn = UnixStream::connect(&socket_path).await.unwrap();
296//         let (server_conn, _) = listener.accept().await.unwrap();
297
298//         let mut server_stream = MessageStream::new(server_conn);
299//         let client = Peer::new();
300//         let (connection_id, handler) = client.add_connection(client_conn).await;
301//         executor.spawn(handler).detach();
302
303//         let client_req = client.request(
304//             connection_id,
305//             proto::Auth {
306//                 user_id: 42,
307//                 access_token: "token".to_string(),
308//             },
309//         );
310//         smol::pin!(client_req);
311//         let server_req = send_recv(&mut client_req, server_stream.read_message())
312//             .await
313//             .unwrap();
314//         assert_eq!(
315//             server_req.payload,
316//             Some(proto::envelope::Payload::Auth(proto::Auth {
317//                 user_id: 42,
318//                 access_token: "token".to_string()
319//             }))
320//         );
321
322//         // Respond to another request to ensure requests are properly matched up.
323//         server_stream
324//             .write_message(
325//                 &proto::AuthResponse {
326//                     credentials_valid: false,
327//                 }
328//                 .into_envelope(1000, Some(999)),
329//             )
330//             .await
331//             .unwrap();
332//         server_stream
333//             .write_message(
334//                 &proto::AuthResponse {
335//                     credentials_valid: true,
336//                 }
337//                 .into_envelope(1001, Some(server_req.id)),
338//             )
339//             .await
340//             .unwrap();
341//         assert_eq!(
342//             client_req.await.unwrap(),
343//             proto::AuthResponse {
344//                 credentials_valid: true
345//             }
346//         );
347//     }
348
349//     #[gpui::test]
350//     async fn test_disconnect(cx: gpui::TestAppContext) {
351//         let executor = cx.read(|app| app.background_executor().clone());
352//         let socket_dir_path = TempDir::new("drop-client").unwrap();
353//         let socket_path = socket_dir_path.path().join(".sock");
354//         let listener = UnixListener::bind(&socket_path).unwrap();
355//         let client_conn = UnixStream::connect(&socket_path).await.unwrap();
356//         let (mut server_conn, _) = listener.accept().await.unwrap();
357
358//         let client = Peer::new();
359//         let (connection_id, handler) = client.add_connection(client_conn).await;
360//         executor.spawn(handler).detach();
361//         client.disconnect(connection_id).await;
362
363//         // Try sending an empty payload over and over, until the client is dropped and hangs up.
364//         loop {
365//             match server_conn.write(&[]).await {
366//                 Ok(_) => {}
367//                 Err(err) => {
368//                     if err.kind() == io::ErrorKind::BrokenPipe {
369//                         break;
370//                     }
371//                 }
372//             }
373//         }
374//     }
375
376//     #[gpui::test]
377//     async fn test_io_error(cx: gpui::TestAppContext) {
378//         let executor = cx.read(|app| app.background_executor().clone());
379//         let socket_dir_path = TempDir::new("io-error").unwrap();
380//         let socket_path = socket_dir_path.path().join(".sock");
381//         let _listener = UnixListener::bind(&socket_path).unwrap();
382//         let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
383//         client_conn.close().await.unwrap();
384
385//         let client = Peer::new();
386//         let (connection_id, handler) = client.add_connection(client_conn).await;
387//         executor.spawn(handler).detach();
388//         let err = client
389//             .request(
390//                 connection_id,
391//                 proto::Auth {
392//                     user_id: 42,
393//                     access_token: "token".to_string(),
394//                 },
395//             )
396//             .await
397//             .unwrap_err();
398//         assert_eq!(
399//             err.downcast_ref::<io::Error>().unwrap().kind(),
400//             io::ErrorKind::BrokenPipe
401//         );
402//     }
403
404//     async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
405//     where
406//         S: Unpin + Future,
407//         R: Future<Output = O>,
408//     {
409//         smol::pin!(receiver);
410//         loop {
411//             poll_once(&mut sender).await;
412//             match poll_once(&mut receiver).await {
413//                 Some(message) => break message,
414//                 None => continue,
415//             }
416//         }
417//     }
418// }