peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::stream::BoxStream;
  5use futures::{FutureExt as _, StreamExt};
  6use parking_lot::{Mutex, RwLock};
  7use postage::{
  8    barrier, mpsc,
  9    prelude::{Sink as _, Stream as _},
 10};
 11use smol_timeout::TimeoutExt as _;
 12use std::sync::atomic::Ordering::SeqCst;
 13use std::{
 14    collections::HashMap,
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22    time::Duration,
 23};
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(pub u32);
 27
 28impl fmt::Display for ConnectionId {
 29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 30        self.0.fmt(f)
 31    }
 32}
 33
 34#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 35pub struct PeerId(pub u32);
 36
 37impl fmt::Display for PeerId {
 38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 39        self.0.fmt(f)
 40    }
 41}
 42
 43pub struct Receipt<T> {
 44    pub sender_id: ConnectionId,
 45    pub message_id: u32,
 46    payload_type: PhantomData<T>,
 47}
 48
 49impl<T> Clone for Receipt<T> {
 50    fn clone(&self) -> Self {
 51        Self {
 52            sender_id: self.sender_id,
 53            message_id: self.message_id,
 54            payload_type: PhantomData,
 55        }
 56    }
 57}
 58
 59impl<T> Copy for Receipt<T> {}
 60
 61pub struct TypedEnvelope<T> {
 62    pub sender_id: ConnectionId,
 63    pub original_sender_id: Option<PeerId>,
 64    pub message_id: u32,
 65    pub payload: T,
 66}
 67
 68impl<T> TypedEnvelope<T> {
 69    pub fn original_sender_id(&self) -> Result<PeerId> {
 70        self.original_sender_id
 71            .ok_or_else(|| anyhow!("missing original_sender_id"))
 72    }
 73}
 74
 75impl<T: RequestMessage> TypedEnvelope<T> {
 76    pub fn receipt(&self) -> Receipt<T> {
 77        Receipt {
 78            sender_id: self.sender_id,
 79            message_id: self.message_id,
 80            payload_type: PhantomData,
 81        }
 82    }
 83}
 84
 85pub struct Peer {
 86    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 87    next_connection_id: AtomicU32,
 88}
 89
 90#[derive(Clone)]
 91pub struct ConnectionState {
 92    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
 93    next_message_id: Arc<AtomicU32>,
 94    response_channels:
 95        Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 96}
 97
 98const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 99
100impl Peer {
101    pub fn new() -> Arc<Self> {
102        Arc::new(Self {
103            connections: Default::default(),
104            next_connection_id: Default::default(),
105        })
106    }
107
108    pub async fn add_connection(
109        self: &Arc<Self>,
110        connection: Connection,
111    ) -> (
112        ConnectionId,
113        impl Future<Output = anyhow::Result<()>> + Send,
114        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
115    ) {
116        // For outgoing messages, use an unbounded channel so that application code
117        // can always send messages without yielding. For incoming messages, use a
118        // bounded channel so that other peers will receive backpressure if they send
119        // messages faster than this peer can process them.
120        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
122
123        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
124        let connection_state = ConnectionState {
125            outgoing_tx,
126            next_message_id: Default::default(),
127            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
128        };
129        let mut writer = MessageStream::new(connection.tx);
130        let mut reader = MessageStream::new(connection.rx);
131
132        let this = self.clone();
133        let response_channels = connection_state.response_channels.clone();
134        let handle_io = async move {
135            let result = 'outer: loop {
136                let read_message = reader.read_message().fuse();
137                futures::pin_mut!(read_message);
138                loop {
139                    futures::select_biased! {
140                        outgoing = outgoing_rx.next().fuse() => match outgoing {
141                            Some(outgoing) => {
142                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
143                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
144                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
145                                    _ => {}
146                                }
147                            }
148                            None => break 'outer Ok(()),
149                        },
150                        incoming = read_message => match incoming {
151                            Ok(incoming) => {
152                                if incoming_tx.send(incoming).await.is_err() {
153                                    break 'outer Ok(());
154                                }
155                                break;
156                            }
157                            Err(error) => {
158                                break 'outer Err(error).context("received invalid RPC message")
159                            }
160                        },
161                    }
162                }
163            };
164
165            response_channels.lock().take();
166            this.connections.write().remove(&connection_id);
167            result
168        };
169
170        let response_channels = connection_state.response_channels.clone();
171        self.connections
172            .write()
173            .insert(connection_id, connection_state);
174
175        let incoming_rx = incoming_rx.filter_map(move |incoming| {
176            let response_channels = response_channels.clone();
177            async move {
178                if let Some(responding_to) = incoming.responding_to {
179                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
180                    if let Some(mut tx) = channel {
181                        let mut requester_resumed = barrier::channel();
182                        tx.send((incoming, requester_resumed.0)).await.ok();
183                        // Drop response channel before awaiting on the barrier. This allows the
184                        // barrier to get dropped even if the request's future is dropped before it
185                        // has a chance to observe the response.
186                        drop(tx);
187                        requester_resumed.1.recv().await;
188                    } else {
189                        log::warn!("received RPC response to unknown request {}", responding_to);
190                    }
191
192                    None
193                } else {
194                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
195                        Some(envelope)
196                    } else {
197                        log::error!("unable to construct a typed envelope");
198                        None
199                    }
200                }
201            }
202        });
203        (connection_id, handle_io, incoming_rx.boxed())
204    }
205
206    pub fn disconnect(&self, connection_id: ConnectionId) {
207        self.connections.write().remove(&connection_id);
208    }
209
210    pub fn reset(&self) {
211        self.connections.write().clear();
212    }
213
214    pub fn request<T: RequestMessage>(
215        &self,
216        receiver_id: ConnectionId,
217        request: T,
218    ) -> impl Future<Output = Result<T::Response>> {
219        self.request_internal(None, receiver_id, request)
220    }
221
222    pub fn forward_request<T: RequestMessage>(
223        &self,
224        sender_id: ConnectionId,
225        receiver_id: ConnectionId,
226        request: T,
227    ) -> impl Future<Output = Result<T::Response>> {
228        self.request_internal(Some(sender_id), receiver_id, request)
229    }
230
231    pub fn request_internal<T: RequestMessage>(
232        &self,
233        original_sender_id: Option<ConnectionId>,
234        receiver_id: ConnectionId,
235        request: T,
236    ) -> impl Future<Output = Result<T::Response>> {
237        let (tx, mut rx) = mpsc::channel(1);
238        let send = self.connection_state(receiver_id).and_then(|connection| {
239            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
240            connection
241                .response_channels
242                .lock()
243                .as_mut()
244                .ok_or_else(|| anyhow!("connection was closed"))?
245                .insert(message_id, tx);
246            connection
247                .outgoing_tx
248                .unbounded_send(request.into_envelope(
249                    message_id,
250                    None,
251                    original_sender_id.map(|id| id.0),
252                ))
253                .map_err(|_| anyhow!("connection was closed"))?;
254            Ok(())
255        });
256        async move {
257            send?;
258            let (response, _barrier) = rx
259                .recv()
260                .await
261                .ok_or_else(|| anyhow!("connection was closed"))?;
262            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
263                Err(anyhow!("request failed").context(error.message.clone()))
264            } else {
265                T::Response::from_envelope(response)
266                    .ok_or_else(|| anyhow!("received response of the wrong type"))
267            }
268        }
269    }
270
271    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
272        let connection = self.connection_state(receiver_id)?;
273        let message_id = connection
274            .next_message_id
275            .fetch_add(1, atomic::Ordering::SeqCst);
276        connection
277            .outgoing_tx
278            .unbounded_send(message.into_envelope(message_id, None, None))?;
279        Ok(())
280    }
281
282    pub fn forward_send<T: EnvelopedMessage>(
283        &self,
284        sender_id: ConnectionId,
285        receiver_id: ConnectionId,
286        message: T,
287    ) -> Result<()> {
288        let connection = self.connection_state(receiver_id)?;
289        let message_id = connection
290            .next_message_id
291            .fetch_add(1, atomic::Ordering::SeqCst);
292        connection
293            .outgoing_tx
294            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
295        Ok(())
296    }
297
298    pub fn respond<T: RequestMessage>(
299        &self,
300        receipt: Receipt<T>,
301        response: T::Response,
302    ) -> Result<()> {
303        let connection = self.connection_state(receipt.sender_id)?;
304        let message_id = connection
305            .next_message_id
306            .fetch_add(1, atomic::Ordering::SeqCst);
307        connection
308            .outgoing_tx
309            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
310        Ok(())
311    }
312
313    pub fn respond_with_error<T: RequestMessage>(
314        &self,
315        receipt: Receipt<T>,
316        response: proto::Error,
317    ) -> Result<()> {
318        let connection = self.connection_state(receipt.sender_id)?;
319        let message_id = connection
320            .next_message_id
321            .fetch_add(1, atomic::Ordering::SeqCst);
322        connection
323            .outgoing_tx
324            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
325        Ok(())
326    }
327
328    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
329        let connections = self.connections.read();
330        let connection = connections
331            .get(&connection_id)
332            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
333        Ok(connection.clone())
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients connected to one server each send requests and receive the responses the
    /// server's handler produces for them.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(
                    client1_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 1,
                        path: "path/one".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(0))
                }),
            }
        );

        assert_eq!(
            client2
                .request(
                    client2_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 2,
                        path: "path/two".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(1))
                }),
            }
        );

        // Connection ids are issued per peer, so each client must disconnect with the id its
        // own `add_connection` returned.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and `OpenBuffer` with a buffer id chosen by path; panics
        /// on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {
                    let message = &envelope.payload;
                    let receipt = envelope.receipt();
                    let response = match message.path.as_str() {
                        "path/one" => {
                            assert_eq!(message.worktree_id, 1);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(0)),
                                }),
                            }
                        }
                        "path/two" => {
                            assert_eq!(message.worktree_id, 2);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(1)),
                                }),
                            }
                        }
                        _ => {
                            panic!("unexpected path {}", message.path);
                        }
                    };

                    peer.respond(receipt, response)?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages the server sends before its response must reach the client's incoming stream
    /// before the request's future observes the response.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request must not stall delivery of later incoming messages or
    /// of the response to another request.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// Disconnecting ends both the I/O future and the incoming stream, after which the
    /// remote end can no longer send on the connection.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// When the remote side goes away, a pending request fails with "connection was closed".
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}