peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use async_lock::{Mutex, RwLock};
  5use futures::FutureExt as _;
  6use postage::{
  7    mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use smol_timeout::TimeoutExt as _;
 11use std::sync::atomic::Ordering::SeqCst;
 12use std::{
 13    collections::HashMap,
 14    fmt,
 15    future::Future,
 16    marker::PhantomData,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21    time::Duration,
 22};
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 25pub struct ConnectionId(pub u32);
 26
 27impl fmt::Display for ConnectionId {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        self.0.fmt(f)
 30    }
 31}
 32
 33#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 34pub struct PeerId(pub u32);
 35
 36impl fmt::Display for PeerId {
 37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 38        self.0.fmt(f)
 39    }
 40}
 41
 42pub struct Receipt<T> {
 43    pub sender_id: ConnectionId,
 44    pub message_id: u32,
 45    payload_type: PhantomData<T>,
 46}
 47
 48impl<T> Clone for Receipt<T> {
 49    fn clone(&self) -> Self {
 50        Self {
 51            sender_id: self.sender_id,
 52            message_id: self.message_id,
 53            payload_type: PhantomData,
 54        }
 55    }
 56}
 57
 58impl<T> Copy for Receipt<T> {}
 59
 60pub struct TypedEnvelope<T> {
 61    pub sender_id: ConnectionId,
 62    pub original_sender_id: Option<PeerId>,
 63    pub message_id: u32,
 64    pub payload: T,
 65}
 66
 67impl<T> TypedEnvelope<T> {
 68    pub fn original_sender_id(&self) -> Result<PeerId> {
 69        self.original_sender_id
 70            .ok_or_else(|| anyhow!("missing original_sender_id"))
 71    }
 72}
 73
 74impl<T: RequestMessage> TypedEnvelope<T> {
 75    pub fn receipt(&self) -> Receipt<T> {
 76        Receipt {
 77            sender_id: self.sender_id,
 78            message_id: self.message_id,
 79            payload_type: PhantomData,
 80        }
 81    }
 82}
 83
 84pub struct Peer {
 85    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 86    next_connection_id: AtomicU32,
 87}
 88
 89#[derive(Clone)]
 90pub struct ConnectionState {
 91    outgoing_tx: mpsc::Sender<proto::Envelope>,
 92    next_message_id: Arc<AtomicU32>,
 93    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
 94}
 95
 96const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 97
 98impl Peer {
 99    pub fn new() -> Arc<Self> {
100        Arc::new(Self {
101            connections: Default::default(),
102            next_connection_id: Default::default(),
103        })
104    }
105
106    pub async fn add_connection(
107        self: &Arc<Self>,
108        connection: Connection,
109    ) -> (
110        ConnectionId,
111        impl Future<Output = anyhow::Result<()>> + Send,
112        mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
113    ) {
114        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
115        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
116        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
117        let connection_state = ConnectionState {
118            outgoing_tx,
119            next_message_id: Default::default(),
120            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
121        };
122        let mut writer = MessageStream::new(connection.tx);
123        let mut reader = MessageStream::new(connection.rx);
124
125        let this = self.clone();
126        let response_channels = connection_state.response_channels.clone();
127        let handle_io = async move {
128            let result = 'outer: loop {
129                let read_message = reader.read_message().fuse();
130                futures::pin_mut!(read_message);
131                loop {
132                    futures::select_biased! {
133                        incoming = read_message => match incoming {
134                            Ok(incoming) => {
135                                if let Some(responding_to) = incoming.responding_to {
136                                    let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
137                                    if let Some(mut tx) = channel {
138                                        tx.send(incoming).await.ok();
139                                    } else {
140                                        log::warn!("received RPC response to unknown request {}", responding_to);
141                                    }
142                                } else {
143                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
144                                        if incoming_tx.send(envelope).await.is_err() {
145                                            break 'outer Ok(())
146                                        }
147                                    } else {
148                                        log::error!("unable to construct a typed envelope");
149                                    }
150                                }
151
152                                break;
153                            }
154                            Err(error) => {
155                                break 'outer Err(error).context("received invalid RPC message")
156                            }
157                        },
158                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
159                            Some(outgoing) => {
160                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
161                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
162                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
163                                    _ => {}
164                                }
165                            }
166                            None => break 'outer Ok(()),
167                        }
168                    }
169                }
170            };
171
172            response_channels.lock().await.take();
173            this.connections.write().await.remove(&connection_id);
174            result
175        };
176
177        self.connections
178            .write()
179            .await
180            .insert(connection_id, connection_state);
181
182        (connection_id, handle_io, incoming_rx)
183    }
184
185    pub async fn disconnect(&self, connection_id: ConnectionId) {
186        self.connections.write().await.remove(&connection_id);
187    }
188
189    pub async fn reset(&self) {
190        self.connections.write().await.clear();
191    }
192
193    pub fn request<T: RequestMessage>(
194        self: &Arc<Self>,
195        receiver_id: ConnectionId,
196        request: T,
197    ) -> impl Future<Output = Result<T::Response>> {
198        self.request_internal(None, receiver_id, request)
199    }
200
201    pub fn forward_request<T: RequestMessage>(
202        self: &Arc<Self>,
203        sender_id: ConnectionId,
204        receiver_id: ConnectionId,
205        request: T,
206    ) -> impl Future<Output = Result<T::Response>> {
207        self.request_internal(Some(sender_id), receiver_id, request)
208    }
209
210    pub fn request_internal<T: RequestMessage>(
211        self: &Arc<Self>,
212        original_sender_id: Option<ConnectionId>,
213        receiver_id: ConnectionId,
214        request: T,
215    ) -> impl Future<Output = Result<T::Response>> {
216        let this = self.clone();
217        let (tx, mut rx) = mpsc::channel(1);
218        async move {
219            let mut connection = this.connection_state(receiver_id).await?;
220            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
221            connection
222                .response_channels
223                .lock()
224                .await
225                .as_mut()
226                .ok_or_else(|| anyhow!("connection was closed"))?
227                .insert(message_id, tx);
228            connection
229                .outgoing_tx
230                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
231                .await
232                .map_err(|_| anyhow!("connection was closed"))?;
233            let response = rx
234                .recv()
235                .await
236                .ok_or_else(|| anyhow!("connection was closed"))?;
237            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
238                Err(anyhow!("request failed").context(error.message.clone()))
239            } else {
240                T::Response::from_envelope(response)
241                    .ok_or_else(|| anyhow!("received response of the wrong type"))
242            }
243        }
244    }
245
246    pub fn send<T: EnvelopedMessage>(
247        self: &Arc<Self>,
248        receiver_id: ConnectionId,
249        message: T,
250    ) -> impl Future<Output = Result<()>> {
251        let this = self.clone();
252        async move {
253            let mut connection = this.connection_state(receiver_id).await?;
254            let message_id = connection
255                .next_message_id
256                .fetch_add(1, atomic::Ordering::SeqCst);
257            connection
258                .outgoing_tx
259                .send(message.into_envelope(message_id, None, None))
260                .await?;
261            Ok(())
262        }
263    }
264
265    pub fn forward_send<T: EnvelopedMessage>(
266        self: &Arc<Self>,
267        sender_id: ConnectionId,
268        receiver_id: ConnectionId,
269        message: T,
270    ) -> impl Future<Output = Result<()>> {
271        let this = self.clone();
272        async move {
273            let mut connection = this.connection_state(receiver_id).await?;
274            let message_id = connection
275                .next_message_id
276                .fetch_add(1, atomic::Ordering::SeqCst);
277            connection
278                .outgoing_tx
279                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
280                .await?;
281            Ok(())
282        }
283    }
284
285    pub fn respond<T: RequestMessage>(
286        self: &Arc<Self>,
287        receipt: Receipt<T>,
288        response: T::Response,
289    ) -> impl Future<Output = Result<()>> {
290        let this = self.clone();
291        async move {
292            let mut connection = this.connection_state(receipt.sender_id).await?;
293            let message_id = connection
294                .next_message_id
295                .fetch_add(1, atomic::Ordering::SeqCst);
296            connection
297                .outgoing_tx
298                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
299                .await?;
300            Ok(())
301        }
302    }
303
304    pub fn respond_with_error<T: RequestMessage>(
305        self: &Arc<Self>,
306        receipt: Receipt<T>,
307        response: proto::Error,
308    ) -> impl Future<Output = Result<()>> {
309        let this = self.clone();
310        async move {
311            let mut connection = this.connection_state(receipt.sender_id).await?;
312            let message_id = connection
313                .next_message_id
314                .fetch_add(1, atomic::Ordering::SeqCst);
315            connection
316                .outgoing_tx
317                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
318                .await?;
319            Ok(())
320        }
321    }
322
323    fn connection_state(
324        self: &Arc<Self>,
325        connection_id: ConnectionId,
326    ) -> impl Future<Output = Result<ConnectionState>> {
327        let this = self.clone();
328        async move {
329            let connections = this.connections.read().await;
330            let connection = connections
331                .get(&connection_id)
332                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
333            Ok(connection.clone())
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use futures::StreamExt as _;

    /// Two clients talk to one server over in-memory connections; each request
    /// must receive the response the server produced for it.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping {},)
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping {},)
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            project_id: 0,
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostic_sets: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            project_id: 0,
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostic_sets: vec![],
                    }),
                }
            );

            // Each client disconnects its own connection.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            // Server-side handler: answers `Ping` with `Ack`, serves canned
            // `OpenBuffer` responses for "path/one" and "path/two", and panics on
            // anything else.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(receipt, proto::Ack {}).await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostic_sets: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostic_sets: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting ends the I/O future and the incoming stream, and closes
    /// the underlying transport.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(server_conn
                .send(WebSocketMessage::Binary(vec![]))
                .await
                .is_err());
        });
    }

    /// A pending request fails with "connection was closed" when the remote end
    /// of the transport goes away.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let response = smol::spawn(client.request(connection_id, proto::Ping {}));
            let _request = server_conn.rx.next().await.unwrap().unwrap();

            drop(server_conn);
            assert_eq!(
                response.await.unwrap_err().to_string(),
                "connection was closed"
            );
        });
    }
}