peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use async_lock::{Mutex, RwLock};
  5use futures::FutureExt as _;
  6use postage::{
  7    mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use smol_timeout::TimeoutExt as _;
 11use std::sync::atomic::Ordering::SeqCst;
 12use std::{
 13    collections::HashMap,
 14    fmt,
 15    future::Future,
 16    marker::PhantomData,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21    time::Duration,
 22};
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 25pub struct ConnectionId(pub u32);
 26
 27impl fmt::Display for ConnectionId {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        self.0.fmt(f)
 30    }
 31}
 32
 33#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 34pub struct PeerId(pub u32);
 35
 36impl fmt::Display for PeerId {
 37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 38        self.0.fmt(f)
 39    }
 40}
 41
 42pub struct Receipt<T> {
 43    pub sender_id: ConnectionId,
 44    pub message_id: u32,
 45    payload_type: PhantomData<T>,
 46}
 47
 48impl<T> Clone for Receipt<T> {
 49    fn clone(&self) -> Self {
 50        Self {
 51            sender_id: self.sender_id,
 52            message_id: self.message_id,
 53            payload_type: PhantomData,
 54        }
 55    }
 56}
 57
 58impl<T> Copy for Receipt<T> {}
 59
 60pub struct TypedEnvelope<T> {
 61    pub sender_id: ConnectionId,
 62    pub original_sender_id: Option<PeerId>,
 63    pub message_id: u32,
 64    pub payload: T,
 65}
 66
 67impl<T> TypedEnvelope<T> {
 68    pub fn original_sender_id(&self) -> Result<PeerId> {
 69        self.original_sender_id
 70            .ok_or_else(|| anyhow!("missing original_sender_id"))
 71    }
 72}
 73
 74impl<T: RequestMessage> TypedEnvelope<T> {
 75    pub fn receipt(&self) -> Receipt<T> {
 76        Receipt {
 77            sender_id: self.sender_id,
 78            message_id: self.message_id,
 79            payload_type: PhantomData,
 80        }
 81    }
 82}
 83
 84pub struct Peer {
 85    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 86    next_connection_id: AtomicU32,
 87}
 88
 89#[derive(Clone)]
 90pub struct ConnectionState {
 91    outgoing_tx: mpsc::Sender<proto::Envelope>,
 92    next_message_id: Arc<AtomicU32>,
 93    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
 94}
 95
 96const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 97
 98impl Peer {
 99    pub fn new() -> Arc<Self> {
100        Arc::new(Self {
101            connections: Default::default(),
102            next_connection_id: Default::default(),
103        })
104    }
105
106    pub async fn add_connection(
107        self: &Arc<Self>,
108        connection: Connection,
109    ) -> (
110        ConnectionId,
111        impl Future<Output = anyhow::Result<()>> + Send,
112        mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
113    ) {
114        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
115        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
116        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
117        let connection_state = ConnectionState {
118            outgoing_tx,
119            next_message_id: Default::default(),
120            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
121        };
122        let mut writer = MessageStream::new(connection.tx);
123        let mut reader = MessageStream::new(connection.rx);
124
125        let this = self.clone();
126        let response_channels = connection_state.response_channels.clone();
127        let handle_io = async move {
128            let result = 'outer: loop {
129                let read_message = reader.read_message().fuse();
130                futures::pin_mut!(read_message);
131                loop {
132                    futures::select_biased! {
133                        incoming = read_message => match incoming {
134                            Ok(incoming) => {
135                                if let Some(responding_to) = incoming.responding_to {
136                                    let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
137                                    if let Some(mut tx) = channel {
138                                        tx.send(incoming).await.ok();
139                                    } else {
140                                        log::warn!("received RPC response to unknown request {}", responding_to);
141                                    }
142                                } else {
143                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
144                                        if incoming_tx.send(envelope).await.is_err() {
145                                            break 'outer Ok(())
146                                        }
147                                    } else {
148                                        log::error!("unable to construct a typed envelope");
149                                    }
150                                }
151
152                                break;
153                            }
154                            Err(error) => {
155                                break 'outer Err(error).context("received invalid RPC message")
156                            }
157                        },
158                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
159                            Some(outgoing) => {
160                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
161                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
162                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
163                                    _ => {}
164                                }
165                            }
166                            None => break 'outer Ok(()),
167                        }
168                    }
169                }
170            };
171
172            response_channels.lock().await.take();
173            this.connections.write().await.remove(&connection_id);
174            result
175        };
176
177        self.connections
178            .write()
179            .await
180            .insert(connection_id, connection_state);
181
182        (connection_id, handle_io, incoming_rx)
183    }
184
185    pub async fn disconnect(&self, connection_id: ConnectionId) {
186        self.connections.write().await.remove(&connection_id);
187    }
188
189    pub async fn reset(&self) {
190        self.connections.write().await.clear();
191    }
192
193    pub fn request<T: RequestMessage>(
194        self: &Arc<Self>,
195        receiver_id: ConnectionId,
196        request: T,
197    ) -> impl Future<Output = Result<T::Response>> {
198        self.request_internal(None, receiver_id, request)
199    }
200
201    pub fn forward_request<T: RequestMessage>(
202        self: &Arc<Self>,
203        sender_id: ConnectionId,
204        receiver_id: ConnectionId,
205        request: T,
206    ) -> impl Future<Output = Result<T::Response>> {
207        self.request_internal(Some(sender_id), receiver_id, request)
208    }
209
210    pub fn request_internal<T: RequestMessage>(
211        self: &Arc<Self>,
212        original_sender_id: Option<ConnectionId>,
213        receiver_id: ConnectionId,
214        request: T,
215    ) -> impl Future<Output = Result<T::Response>> {
216        let this = self.clone();
217        let (tx, mut rx) = mpsc::channel(1);
218        async move {
219            let mut connection = this.connection_state(receiver_id).await?;
220            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
221            connection
222                .response_channels
223                .lock()
224                .await
225                .as_mut()
226                .ok_or_else(|| anyhow!("connection was closed"))?
227                .insert(message_id, tx);
228            connection
229                .outgoing_tx
230                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
231                .await
232                .map_err(|_| anyhow!("connection was closed"))?;
233            let response = rx
234                .recv()
235                .await
236                .ok_or_else(|| anyhow!("connection was closed"))?;
237            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
238                Err(anyhow!("request failed").context(error.message.clone()))
239            } else {
240                T::Response::from_envelope(response)
241                    .ok_or_else(|| anyhow!("received response of the wrong type"))
242            }
243        }
244    }
245
246    pub fn send<T: EnvelopedMessage>(
247        self: &Arc<Self>,
248        receiver_id: ConnectionId,
249        message: T,
250    ) -> impl Future<Output = Result<()>> {
251        let this = self.clone();
252        async move {
253            let mut connection = this.connection_state(receiver_id).await?;
254            let message_id = connection
255                .next_message_id
256                .fetch_add(1, atomic::Ordering::SeqCst);
257            connection
258                .outgoing_tx
259                .send(message.into_envelope(message_id, None, None))
260                .await?;
261            Ok(())
262        }
263    }
264
265    pub fn forward_send<T: EnvelopedMessage>(
266        self: &Arc<Self>,
267        sender_id: ConnectionId,
268        receiver_id: ConnectionId,
269        message: T,
270    ) -> impl Future<Output = Result<()>> {
271        let this = self.clone();
272        async move {
273            let mut connection = this.connection_state(receiver_id).await?;
274            let message_id = connection
275                .next_message_id
276                .fetch_add(1, atomic::Ordering::SeqCst);
277            connection
278                .outgoing_tx
279                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
280                .await?;
281            Ok(())
282        }
283    }
284
285    pub fn respond<T: RequestMessage>(
286        self: &Arc<Self>,
287        receipt: Receipt<T>,
288        response: T::Response,
289    ) -> impl Future<Output = Result<()>> {
290        let this = self.clone();
291        async move {
292            let mut connection = this.connection_state(receipt.sender_id).await?;
293            let message_id = connection
294                .next_message_id
295                .fetch_add(1, atomic::Ordering::SeqCst);
296            connection
297                .outgoing_tx
298                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
299                .await?;
300            Ok(())
301        }
302    }
303
304    pub fn respond_with_error<T: RequestMessage>(
305        self: &Arc<Self>,
306        receipt: Receipt<T>,
307        response: proto::Error,
308    ) -> impl Future<Output = Result<()>> {
309        let this = self.clone();
310        async move {
311            let mut connection = this.connection_state(receipt.sender_id).await?;
312            let message_id = connection
313                .next_message_id
314                .fetch_add(1, atomic::Ordering::SeqCst);
315            connection
316                .outgoing_tx
317                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
318                .await?;
319            Ok(())
320        }
321    }
322
323    fn connection_state(
324        self: &Arc<Self>,
325        connection_id: ConnectionId,
326    ) -> impl Future<Output = Result<ConnectionState>> {
327        let this = self.clone();
328        async move {
329            let connections = this.connections.read().await;
330            let connection = connections
331                .get(&connection_id)
332                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
333            Ok(connection.clone())
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use futures::StreamExt as _;

    /// Two clients talk to one server; each request must get the matching response.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping {},)
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping {},)
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostics: None,
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostics: None,
                    }),
                }
            );

            // Each client disconnects its own connection.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Server-side handler: answers `Ping` with `Ack` and `OpenBuffer`
            /// with a canned buffer for the two known paths.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(receipt, proto::Ack {}).await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostics: None,
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostics: None,
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting ends the I/O future, closes the incoming stream, and closes
    /// the underlying transport.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(server_conn
                .send(WebSocketMessage::Binary(vec![]))
                .await
                .is_err());
        });
    }

    /// A transport failure fails any request still awaiting its response.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let response = smol::spawn(client.request(connection_id, proto::Ping {}));
            let _request = server_conn.rx.next().await.unwrap().unwrap();

            drop(server_conn);
            assert_eq!(
                response.await.unwrap_err().to_string(),
                "connection was closed"
            );
        });
    }
}