peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::stream::BoxStream;
  5use futures::{FutureExt as _, StreamExt};
  6use parking_lot::{Mutex, RwLock};
  7use postage::{
  8    barrier, mpsc,
  9    prelude::{Sink as _, Stream as _},
 10};
 11use smol_timeout::TimeoutExt as _;
 12use std::sync::atomic::Ordering::SeqCst;
 13use std::{
 14    collections::HashMap,
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22    time::Duration,
 23};
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(pub u32);
 27
 28impl fmt::Display for ConnectionId {
 29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 30        self.0.fmt(f)
 31    }
 32}
 33
 34#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 35pub struct PeerId(pub u32);
 36
 37impl fmt::Display for PeerId {
 38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 39        self.0.fmt(f)
 40    }
 41}
 42
 43pub struct Receipt<T> {
 44    pub sender_id: ConnectionId,
 45    pub message_id: u32,
 46    payload_type: PhantomData<T>,
 47}
 48
 49impl<T> Clone for Receipt<T> {
 50    fn clone(&self) -> Self {
 51        Self {
 52            sender_id: self.sender_id,
 53            message_id: self.message_id,
 54            payload_type: PhantomData,
 55        }
 56    }
 57}
 58
 59impl<T> Copy for Receipt<T> {}
 60
 61pub struct TypedEnvelope<T> {
 62    pub sender_id: ConnectionId,
 63    pub original_sender_id: Option<PeerId>,
 64    pub message_id: u32,
 65    pub payload: T,
 66}
 67
 68impl<T> TypedEnvelope<T> {
 69    pub fn original_sender_id(&self) -> Result<PeerId> {
 70        self.original_sender_id
 71            .ok_or_else(|| anyhow!("missing original_sender_id"))
 72    }
 73}
 74
 75impl<T: RequestMessage> TypedEnvelope<T> {
 76    pub fn receipt(&self) -> Receipt<T> {
 77        Receipt {
 78            sender_id: self.sender_id,
 79            message_id: self.message_id,
 80            payload_type: PhantomData,
 81        }
 82    }
 83}
 84
 85pub struct Peer {
 86    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 87    next_connection_id: AtomicU32,
 88}
 89
 90#[derive(Clone)]
 91pub struct ConnectionState {
 92    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
 93    next_message_id: Arc<AtomicU32>,
 94    response_channels:
 95        Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 96}
 97
 98const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 99
100impl Peer {
101    pub fn new() -> Arc<Self> {
102        Arc::new(Self {
103            connections: Default::default(),
104            next_connection_id: Default::default(),
105        })
106    }
107
108    pub async fn add_connection(
109        self: &Arc<Self>,
110        connection: Connection,
111    ) -> (
112        ConnectionId,
113        impl Future<Output = anyhow::Result<()>> + Send,
114        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
115    ) {
116        // For outgoing messages, use an unbounded channel so that application code
117        // can always send messages without yielding. For incoming messages, use a
118        // bounded channel so that other peers will receive backpressure if they send
119        // messages faster than this peer can process them.
120        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
122
123        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
124        let connection_state = ConnectionState {
125            outgoing_tx,
126            next_message_id: Default::default(),
127            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
128        };
129        let mut writer = MessageStream::new(connection.tx);
130        let mut reader = MessageStream::new(connection.rx);
131
132        let this = self.clone();
133        let response_channels = connection_state.response_channels.clone();
134        let handle_io = async move {
135            let result = 'outer: loop {
136                let read_message = reader.read_message().fuse();
137                futures::pin_mut!(read_message);
138                loop {
139                    futures::select_biased! {
140                        outgoing = outgoing_rx.next().fuse() => match outgoing {
141                            Some(outgoing) => {
142                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
143                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
144                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
145                                    _ => {}
146                                }
147                            }
148                            None => break 'outer Ok(()),
149                        },
150                        incoming = read_message => match incoming {
151                            Ok(incoming) => {
152                                if incoming_tx.send(incoming).await.is_err() {
153                                    break 'outer Ok(());
154                                }
155                                break;
156                            }
157                            Err(error) => {
158                                break 'outer Err(error).context("received invalid RPC message")
159                            }
160                        },
161                    }
162                }
163            };
164
165            response_channels.lock().take();
166            this.connections.write().remove(&connection_id);
167            result
168        };
169
170        let response_channels = connection_state.response_channels.clone();
171        self.connections
172            .write()
173            .insert(connection_id, connection_state);
174
175        let incoming_rx = incoming_rx.filter_map(move |incoming| {
176            let response_channels = response_channels.clone();
177            async move {
178                if let Some(responding_to) = incoming.responding_to {
179                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
180                    if let Some(mut tx) = channel {
181                        let mut requester_resumed = barrier::channel();
182                        if let Err(error) = tx.send((incoming, requester_resumed.0)).await {
183                            log::debug!(
184                                "received RPC but request future was dropped {:?}",
185                                error.0 .0
186                            );
187                        }
188                        // Drop response channel before awaiting on the barrier. This allows the
189                        // barrier to get dropped even if the request's future is dropped before it
190                        // has a chance to observe the response.
191                        drop(tx);
192                        requester_resumed.1.recv().await;
193                    } else {
194                        log::warn!("received RPC response to unknown request {}", responding_to);
195                    }
196
197                    None
198                } else {
199                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
200                        Some(envelope)
201                    } else {
202                        log::error!("unable to construct a typed envelope");
203                        None
204                    }
205                }
206            }
207        });
208        (connection_id, handle_io, incoming_rx.boxed())
209    }
210
211    pub fn disconnect(&self, connection_id: ConnectionId) {
212        self.connections.write().remove(&connection_id);
213    }
214
215    pub fn reset(&self) {
216        self.connections.write().clear();
217    }
218
219    pub fn request<T: RequestMessage>(
220        &self,
221        receiver_id: ConnectionId,
222        request: T,
223    ) -> impl Future<Output = Result<T::Response>> {
224        self.request_internal(None, receiver_id, request)
225    }
226
227    pub fn forward_request<T: RequestMessage>(
228        &self,
229        sender_id: ConnectionId,
230        receiver_id: ConnectionId,
231        request: T,
232    ) -> impl Future<Output = Result<T::Response>> {
233        self.request_internal(Some(sender_id), receiver_id, request)
234    }
235
236    pub fn request_internal<T: RequestMessage>(
237        &self,
238        original_sender_id: Option<ConnectionId>,
239        receiver_id: ConnectionId,
240        request: T,
241    ) -> impl Future<Output = Result<T::Response>> {
242        let (tx, mut rx) = mpsc::channel(1);
243        let send = self.connection_state(receiver_id).and_then(|connection| {
244            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
245            connection
246                .response_channels
247                .lock()
248                .as_mut()
249                .ok_or_else(|| anyhow!("connection was closed"))?
250                .insert(message_id, tx);
251            connection
252                .outgoing_tx
253                .unbounded_send(request.into_envelope(
254                    message_id,
255                    None,
256                    original_sender_id.map(|id| id.0),
257                ))
258                .map_err(|_| anyhow!("connection was closed"))?;
259            Ok(())
260        });
261        async move {
262            send?;
263            let (response, _barrier) = rx
264                .recv()
265                .await
266                .ok_or_else(|| anyhow!("connection was closed"))?;
267            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
268                Err(anyhow!("request failed").context(error.message.clone()))
269            } else {
270                T::Response::from_envelope(response)
271                    .ok_or_else(|| anyhow!("received response of the wrong type"))
272            }
273        }
274    }
275
276    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
277        let connection = self.connection_state(receiver_id)?;
278        let message_id = connection
279            .next_message_id
280            .fetch_add(1, atomic::Ordering::SeqCst);
281        connection
282            .outgoing_tx
283            .unbounded_send(message.into_envelope(message_id, None, None))?;
284        Ok(())
285    }
286
287    pub fn forward_send<T: EnvelopedMessage>(
288        &self,
289        sender_id: ConnectionId,
290        receiver_id: ConnectionId,
291        message: T,
292    ) -> Result<()> {
293        let connection = self.connection_state(receiver_id)?;
294        let message_id = connection
295            .next_message_id
296            .fetch_add(1, atomic::Ordering::SeqCst);
297        connection
298            .outgoing_tx
299            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
300        Ok(())
301    }
302
303    pub fn respond<T: RequestMessage>(
304        &self,
305        receipt: Receipt<T>,
306        response: T::Response,
307    ) -> Result<()> {
308        let connection = self.connection_state(receipt.sender_id)?;
309        let message_id = connection
310            .next_message_id
311            .fetch_add(1, atomic::Ordering::SeqCst);
312        connection
313            .outgoing_tx
314            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
315        Ok(())
316    }
317
318    pub fn respond_with_error<T: RequestMessage>(
319        &self,
320        receipt: Receipt<T>,
321        response: proto::Error,
322    ) -> Result<()> {
323        let connection = self.connection_state(receipt.sender_id)?;
324        let message_id = connection
325            .next_message_id
326            .fetch_add(1, atomic::Ordering::SeqCst);
327        connection
328            .outgoing_tx
329            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
330        Ok(())
331    }
332
333    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
334        let connections = self.connections.read();
335        let connection = connections
336            .get(&connection_id)
337            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
338        Ok(connection.clone())
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients share one server; each client's requests must be answered with responses
    /// derived from that client's own request payloads.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(
                    client1_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 1,
                        path: "path/one".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(0))
                }),
            }
        );

        assert_eq!(
            client2
                .request(
                    client2_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 2,
                        path: "path/two".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(1))
                })
            }
        );

        // Each client disconnects using the id of its *own* connection. (Ids are allocated
        // per peer, so the two ids happen to be equal numerically, which hid this mix-up.)
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and `OpenBuffer` with a buffer id chosen by path;
        /// panics on any other message.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {
                    let message = &envelope.payload;
                    let receipt = envelope.receipt();
                    let response = match message.path.as_str() {
                        "path/one" => {
                            assert_eq!(message.worktree_id, 1);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(0)),
                                }),
                            }
                        }
                        "path/two" => {
                            assert_eq!(message.worktree_id, 2);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(1)),
                                }),
                            }
                        }
                        _ => {
                            panic!("unexpected path {}", message.path);
                        }
                    };

                    peer.respond(receipt, response)?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages the server sends before its response must reach the client's incoming
    /// stream before the request future resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one pending request must not stall delivery of incoming messages or of
    /// other requests' responses.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// Disconnecting ends both the I/O future and the incoming stream, and closes the
    /// underlying connection so the remote side can no longer send.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the remote side drops the
    /// connection before responding.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}