peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::stream::BoxStream;
  5use futures::{FutureExt as _, StreamExt};
  6use parking_lot::{Mutex, RwLock};
  7use postage::{
  8    mpsc,
  9    prelude::{Sink as _, Stream as _},
 10};
 11use smol_timeout::TimeoutExt as _;
 12use std::sync::atomic::Ordering::SeqCst;
 13use std::{
 14    collections::HashMap,
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22    time::Duration,
 23};
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(pub u32);
 27
 28impl fmt::Display for ConnectionId {
 29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 30        self.0.fmt(f)
 31    }
 32}
 33
 34#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 35pub struct PeerId(pub u32);
 36
 37impl fmt::Display for PeerId {
 38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 39        self.0.fmt(f)
 40    }
 41}
 42
 43pub struct Receipt<T> {
 44    pub sender_id: ConnectionId,
 45    pub message_id: u32,
 46    payload_type: PhantomData<T>,
 47}
 48
 49impl<T> Clone for Receipt<T> {
 50    fn clone(&self) -> Self {
 51        Self {
 52            sender_id: self.sender_id,
 53            message_id: self.message_id,
 54            payload_type: PhantomData,
 55        }
 56    }
 57}
 58
 59impl<T> Copy for Receipt<T> {}
 60
 61pub struct TypedEnvelope<T> {
 62    pub sender_id: ConnectionId,
 63    pub original_sender_id: Option<PeerId>,
 64    pub message_id: u32,
 65    pub payload: T,
 66}
 67
 68impl<T> TypedEnvelope<T> {
 69    pub fn original_sender_id(&self) -> Result<PeerId> {
 70        self.original_sender_id
 71            .ok_or_else(|| anyhow!("missing original_sender_id"))
 72    }
 73}
 74
 75impl<T: RequestMessage> TypedEnvelope<T> {
 76    pub fn receipt(&self) -> Receipt<T> {
 77        Receipt {
 78            sender_id: self.sender_id,
 79            message_id: self.message_id,
 80            payload_type: PhantomData,
 81        }
 82    }
 83}
 84
 85pub struct Peer {
 86    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 87    next_connection_id: AtomicU32,
 88}
 89
 90#[derive(Clone)]
 91pub struct ConnectionState {
 92    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
 93    next_message_id: Arc<AtomicU32>,
 94    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
 95}
 96
 97const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 98
 99impl Peer {
100    pub fn new() -> Arc<Self> {
101        Arc::new(Self {
102            connections: Default::default(),
103            next_connection_id: Default::default(),
104        })
105    }
106
107    pub async fn add_connection(
108        self: &Arc<Self>,
109        connection: Connection,
110    ) -> (
111        ConnectionId,
112        impl Future<Output = anyhow::Result<()>> + Send,
113        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
114    ) {
115        // For outgoing messages, use an unbounded channel so that application code
116        // can always send messages without yielding. For incoming messages, use a
117        // bounded channel so that other peers will receive backpressure if they send
118        // messages faster than this peer can process them.
119        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
120        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
121
122        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
123        let connection_state = ConnectionState {
124            outgoing_tx,
125            next_message_id: Default::default(),
126            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
127        };
128        let mut writer = MessageStream::new(connection.tx);
129        let mut reader = MessageStream::new(connection.rx);
130
131        let this = self.clone();
132        let response_channels = connection_state.response_channels.clone();
133        let handle_io = async move {
134            let result = 'outer: loop {
135                let read_message = reader.read_message().fuse();
136                futures::pin_mut!(read_message);
137                loop {
138                    futures::select_biased! {
139                        outgoing = outgoing_rx.next().fuse() => match outgoing {
140                            Some(outgoing) => {
141                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
142                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
143                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
144                                    _ => {}
145                                }
146                            }
147                            None => break 'outer Ok(()),
148                        },
149                        incoming = read_message => match incoming {
150                            Ok(incoming) => {
151                                if incoming_tx.send(incoming).await.is_err() {
152                                    break 'outer Ok(());
153                                }
154                                break;
155                            }
156                            Err(error) => {
157                                break 'outer Err(error).context("received invalid RPC message")
158                            }
159                        },
160                    }
161                }
162            };
163
164            response_channels.lock().take();
165            this.connections.write().remove(&connection_id);
166            result
167        };
168
169        let response_channels = connection_state.response_channels.clone();
170        self.connections
171            .write()
172            .insert(connection_id, connection_state);
173
174        let incoming_rx = incoming_rx.filter_map(move |incoming| {
175            let response_channels = response_channels.clone();
176            async move {
177                if let Some(responding_to) = incoming.responding_to {
178                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
179                    if let Some(mut tx) = channel {
180                        if let Err(error) = tx.send(incoming).await {
181                            log::debug!(
182                                "received RPC but request future was dropped {:?}",
183                                error.0
184                            );
185                        }
186                    } else {
187                        log::warn!("received RPC response to unknown request {}", responding_to);
188                    }
189
190                    None
191                } else {
192                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
193                        Some(envelope)
194                    } else {
195                        log::error!("unable to construct a typed envelope");
196                        None
197                    }
198                }
199            }
200        });
201        (connection_id, handle_io, incoming_rx.boxed())
202    }
203
204    pub fn disconnect(&self, connection_id: ConnectionId) {
205        self.connections.write().remove(&connection_id);
206    }
207
208    pub fn reset(&self) {
209        self.connections.write().clear();
210    }
211
212    pub fn request<T: RequestMessage>(
213        &self,
214        receiver_id: ConnectionId,
215        request: T,
216    ) -> impl Future<Output = Result<T::Response>> {
217        self.request_internal(None, receiver_id, request)
218    }
219
220    pub fn forward_request<T: RequestMessage>(
221        &self,
222        sender_id: ConnectionId,
223        receiver_id: ConnectionId,
224        request: T,
225    ) -> impl Future<Output = Result<T::Response>> {
226        self.request_internal(Some(sender_id), receiver_id, request)
227    }
228
229    pub fn request_internal<T: RequestMessage>(
230        &self,
231        original_sender_id: Option<ConnectionId>,
232        receiver_id: ConnectionId,
233        request: T,
234    ) -> impl Future<Output = Result<T::Response>> {
235        let (tx, mut rx) = mpsc::channel(1);
236        let send = self.connection_state(receiver_id).and_then(|connection| {
237            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
238            connection
239                .response_channels
240                .lock()
241                .as_mut()
242                .ok_or_else(|| anyhow!("connection was closed"))?
243                .insert(message_id, tx);
244            connection
245                .outgoing_tx
246                .unbounded_send(request.into_envelope(
247                    message_id,
248                    None,
249                    original_sender_id.map(|id| id.0),
250                ))
251                .map_err(|_| anyhow!("connection was closed"))?;
252            Ok(())
253        });
254        async move {
255            send?;
256            let response = rx
257                .recv()
258                .await
259                .ok_or_else(|| anyhow!("connection was closed"))?;
260            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
261                Err(anyhow!("request failed").context(error.message.clone()))
262            } else {
263                T::Response::from_envelope(response)
264                    .ok_or_else(|| anyhow!("received response of the wrong type"))
265            }
266        }
267    }
268
269    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
270        let connection = self.connection_state(receiver_id)?;
271        let message_id = connection
272            .next_message_id
273            .fetch_add(1, atomic::Ordering::SeqCst);
274        connection
275            .outgoing_tx
276            .unbounded_send(message.into_envelope(message_id, None, None))?;
277        Ok(())
278    }
279
280    pub fn forward_send<T: EnvelopedMessage>(
281        &self,
282        sender_id: ConnectionId,
283        receiver_id: ConnectionId,
284        message: T,
285    ) -> Result<()> {
286        let connection = self.connection_state(receiver_id)?;
287        let message_id = connection
288            .next_message_id
289            .fetch_add(1, atomic::Ordering::SeqCst);
290        connection
291            .outgoing_tx
292            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
293        Ok(())
294    }
295
296    pub fn respond<T: RequestMessage>(
297        &self,
298        receipt: Receipt<T>,
299        response: T::Response,
300    ) -> Result<()> {
301        let connection = self.connection_state(receipt.sender_id)?;
302        let message_id = connection
303            .next_message_id
304            .fetch_add(1, atomic::Ordering::SeqCst);
305        connection
306            .outgoing_tx
307            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
308        Ok(())
309    }
310
311    pub fn respond_with_error<T: RequestMessage>(
312        &self,
313        receipt: Receipt<T>,
314        response: proto::Error,
315    ) -> Result<()> {
316        let connection = self.connection_state(receipt.sender_id)?;
317        let message_id = connection
318            .next_message_id
319            .fetch_add(1, atomic::Ordering::SeqCst);
320        connection
321            .outgoing_tx
322            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
323        Ok(())
324    }
325
326    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
327        let connections = self.connections.read();
328        let connection = connections
329            .get(&connection_id)
330            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
331        Ok(connection.clone())
332    }
333}
334
335#[cfg(test)]
336mod tests {
337    use super::*;
338    use crate::TypedEnvelope;
339    use async_tungstenite::tungstenite::Message as WebSocketMessage;
340    use gpui::TestAppContext;
341
342    #[gpui::test(iterations = 50)]
343    async fn test_request_response(cx: TestAppContext) {
344        let executor = cx.foreground();
345
346        // create 2 clients connected to 1 server
347        let server = Peer::new();
348        let client1 = Peer::new();
349        let client2 = Peer::new();
350
351        let (client1_to_server_conn, server_to_client_1_conn, _) =
352            Connection::in_memory(cx.background());
353        let (client1_conn_id, io_task1, client1_incoming) =
354            client1.add_connection(client1_to_server_conn).await;
355        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
356
357        let (client2_to_server_conn, server_to_client_2_conn, _) =
358            Connection::in_memory(cx.background());
359        let (client2_conn_id, io_task3, client2_incoming) =
360            client2.add_connection(client2_to_server_conn).await;
361        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
362
363        executor.spawn(io_task1).detach();
364        executor.spawn(io_task2).detach();
365        executor.spawn(io_task3).detach();
366        executor.spawn(io_task4).detach();
367        executor
368            .spawn(handle_messages(server_incoming1, server.clone()))
369            .detach();
370        executor
371            .spawn(handle_messages(client1_incoming, client1.clone()))
372            .detach();
373        executor
374            .spawn(handle_messages(server_incoming2, server.clone()))
375            .detach();
376        executor
377            .spawn(handle_messages(client2_incoming, client2.clone()))
378            .detach();
379
380        assert_eq!(
381            client1
382                .request(client1_conn_id, proto::Ping {},)
383                .await
384                .unwrap(),
385            proto::Ack {}
386        );
387
388        assert_eq!(
389            client2
390                .request(client2_conn_id, proto::Ping {},)
391                .await
392                .unwrap(),
393            proto::Ack {}
394        );
395
396        assert_eq!(
397            client1
398                .request(
399                    client1_conn_id,
400                    proto::OpenBuffer {
401                        project_id: 0,
402                        worktree_id: 1,
403                        path: "path/one".to_string(),
404                    },
405                )
406                .await
407                .unwrap(),
408            proto::OpenBufferResponse {
409                buffer: Some(proto::Buffer {
410                    variant: Some(proto::buffer::Variant::Id(0))
411                }),
412            }
413        );
414
415        assert_eq!(
416            client2
417                .request(
418                    client2_conn_id,
419                    proto::OpenBuffer {
420                        project_id: 0,
421                        worktree_id: 2,
422                        path: "path/two".to_string(),
423                    },
424                )
425                .await
426                .unwrap(),
427            proto::OpenBufferResponse {
428                buffer: Some(proto::Buffer {
429                    variant: Some(proto::buffer::Variant::Id(1))
430                })
431            }
432        );
433
434        client1.disconnect(client1_conn_id);
435        client2.disconnect(client1_conn_id);
436
437        async fn handle_messages(
438            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
439            peer: Arc<Peer>,
440        ) -> Result<()> {
441            while let Some(envelope) = messages.next().await {
442                let envelope = envelope.into_any();
443                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
444                    let receipt = envelope.receipt();
445                    peer.respond(receipt, proto::Ack {})?
446                } else if let Some(envelope) =
447                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
448                {
449                    let message = &envelope.payload;
450                    let receipt = envelope.receipt();
451                    let response = match message.path.as_str() {
452                        "path/one" => {
453                            assert_eq!(message.worktree_id, 1);
454                            proto::OpenBufferResponse {
455                                buffer: Some(proto::Buffer {
456                                    variant: Some(proto::buffer::Variant::Id(0)),
457                                }),
458                            }
459                        }
460                        "path/two" => {
461                            assert_eq!(message.worktree_id, 2);
462                            proto::OpenBufferResponse {
463                                buffer: Some(proto::Buffer {
464                                    variant: Some(proto::buffer::Variant::Id(1)),
465                                }),
466                            }
467                        }
468                        _ => {
469                            panic!("unexpected path {}", message.path);
470                        }
471                    };
472
473                    peer.respond(receipt, response)?
474                } else {
475                    panic!("unknown message type");
476                }
477            }
478
479            Ok(())
480        }
481    }
482
483    #[gpui::test(iterations = 50)]
484    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
485        let executor = cx.foreground();
486        let server = Peer::new();
487        let client = Peer::new();
488
489        let (client_to_server_conn, server_to_client_conn, _) =
490            Connection::in_memory(cx.background());
491        let (client_to_server_conn_id, io_task1, mut client_incoming) =
492            client.add_connection(client_to_server_conn).await;
493        let (server_to_client_conn_id, io_task2, mut server_incoming) =
494            server.add_connection(server_to_client_conn).await;
495
496        executor.spawn(io_task1).detach();
497        executor.spawn(io_task2).detach();
498
499        executor
500            .spawn(async move {
501                let request = server_incoming
502                    .next()
503                    .await
504                    .unwrap()
505                    .into_any()
506                    .downcast::<TypedEnvelope<proto::Ping>>()
507                    .unwrap();
508
509                server
510                    .send(
511                        server_to_client_conn_id,
512                        proto::Error {
513                            message: "message 1".to_string(),
514                        },
515                    )
516                    .unwrap();
517                server
518                    .send(
519                        server_to_client_conn_id,
520                        proto::Error {
521                            message: "message 2".to_string(),
522                        },
523                    )
524                    .unwrap();
525                server.respond(request.receipt(), proto::Ack {}).unwrap();
526
527                // Prevent the connection from being dropped
528                server_incoming.next().await;
529            })
530            .detach();
531
532        let events = Arc::new(Mutex::new(Vec::new()));
533
534        let response = client.request(client_to_server_conn_id, proto::Ping {});
535        let response_task = executor.spawn({
536            let events = events.clone();
537            async move {
538                response.await.unwrap();
539                events.lock().push("response".to_string());
540            }
541        });
542
543        executor
544            .spawn({
545                let events = events.clone();
546                async move {
547                    let incoming1 = client_incoming
548                        .next()
549                        .await
550                        .unwrap()
551                        .into_any()
552                        .downcast::<TypedEnvelope<proto::Error>>()
553                        .unwrap();
554                    events.lock().push(incoming1.payload.message);
555                    let incoming2 = client_incoming
556                        .next()
557                        .await
558                        .unwrap()
559                        .into_any()
560                        .downcast::<TypedEnvelope<proto::Error>>()
561                        .unwrap();
562                    events.lock().push(incoming2.payload.message);
563
564                    // Prevent the connection from being dropped
565                    client_incoming.next().await;
566                }
567            })
568            .detach();
569
570        response_task.await;
571        assert_eq!(
572            &*events.lock(),
573            &[
574                "message 1".to_string(),
575                "message 2".to_string(),
576                "response".to_string()
577            ]
578        );
579    }
580
581    #[gpui::test(iterations = 50)]
582    async fn test_dropping_request_before_completion(cx: TestAppContext) {
583        let executor = cx.foreground();
584        let server = Peer::new();
585        let client = Peer::new();
586
587        let (client_to_server_conn, server_to_client_conn, _) =
588            Connection::in_memory(cx.background());
589        let (client_to_server_conn_id, io_task1, mut client_incoming) =
590            client.add_connection(client_to_server_conn).await;
591        let (server_to_client_conn_id, io_task2, mut server_incoming) =
592            server.add_connection(server_to_client_conn).await;
593
594        executor.spawn(io_task1).detach();
595        executor.spawn(io_task2).detach();
596
597        executor
598            .spawn(async move {
599                let request1 = server_incoming
600                    .next()
601                    .await
602                    .unwrap()
603                    .into_any()
604                    .downcast::<TypedEnvelope<proto::Ping>>()
605                    .unwrap();
606                let request2 = server_incoming
607                    .next()
608                    .await
609                    .unwrap()
610                    .into_any()
611                    .downcast::<TypedEnvelope<proto::Ping>>()
612                    .unwrap();
613
614                server
615                    .send(
616                        server_to_client_conn_id,
617                        proto::Error {
618                            message: "message 1".to_string(),
619                        },
620                    )
621                    .unwrap();
622                server
623                    .send(
624                        server_to_client_conn_id,
625                        proto::Error {
626                            message: "message 2".to_string(),
627                        },
628                    )
629                    .unwrap();
630                server.respond(request1.receipt(), proto::Ack {}).unwrap();
631                server.respond(request2.receipt(), proto::Ack {}).unwrap();
632
633                // Prevent the connection from being dropped
634                server_incoming.next().await;
635            })
636            .detach();
637
638        let events = Arc::new(Mutex::new(Vec::new()));
639
640        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
641        let request1_task = executor.spawn(request1);
642        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
643        let request2_task = executor.spawn({
644            let events = events.clone();
645            async move {
646                request2.await.unwrap();
647                events.lock().push("response 2".to_string());
648            }
649        });
650
651        executor
652            .spawn({
653                let events = events.clone();
654                async move {
655                    let incoming1 = client_incoming
656                        .next()
657                        .await
658                        .unwrap()
659                        .into_any()
660                        .downcast::<TypedEnvelope<proto::Error>>()
661                        .unwrap();
662                    events.lock().push(incoming1.payload.message);
663                    let incoming2 = client_incoming
664                        .next()
665                        .await
666                        .unwrap()
667                        .into_any()
668                        .downcast::<TypedEnvelope<proto::Error>>()
669                        .unwrap();
670                    events.lock().push(incoming2.payload.message);
671
672                    // Prevent the connection from being dropped
673                    client_incoming.next().await;
674                }
675            })
676            .detach();
677
678        // Allow the request to make some progress before dropping it.
679        cx.background().simulate_random_delay().await;
680        drop(request1_task);
681
682        request2_task.await;
683        assert_eq!(
684            &*events.lock(),
685            &[
686                "message 1".to_string(),
687                "message 2".to_string(),
688                "response 2".to_string()
689            ]
690        );
691    }
692
693    #[gpui::test(iterations = 50)]
694    async fn test_disconnect(cx: TestAppContext) {
695        let executor = cx.foreground();
696
697        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
698
699        let client = Peer::new();
700        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
701
702        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
703        executor
704            .spawn(async move {
705                io_handler.await.ok();
706                io_ended_tx.send(()).await.unwrap();
707            })
708            .detach();
709
710        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
711        executor
712            .spawn(async move {
713                incoming.next().await;
714                messages_ended_tx.send(()).await.unwrap();
715            })
716            .detach();
717
718        client.disconnect(connection_id);
719
720        io_ended_rx.recv().await;
721        messages_ended_rx.recv().await;
722        assert!(server_conn
723            .send(WebSocketMessage::Binary(vec![]))
724            .await
725            .is_err());
726    }
727
728    #[gpui::test(iterations = 50)]
729    async fn test_io_error(cx: TestAppContext) {
730        let executor = cx.foreground();
731        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
732
733        let client = Peer::new();
734        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
735        executor.spawn(io_handler).detach();
736        executor
737            .spawn(async move { incoming.next().await })
738            .detach();
739
740        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
741        let _request = server_conn.rx.next().await.unwrap().unwrap();
742
743        drop(server_conn);
744        assert_eq!(
745            response.await.unwrap_err().to_string(),
746            "connection was closed"
747        );
748    }
749}