peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::stream::BoxStream;
  5use futures::{FutureExt as _, StreamExt};
  6use parking_lot::{Mutex, RwLock};
  7use postage::{
  8    barrier, mpsc,
  9    prelude::{Sink as _, Stream as _},
 10};
 11use smol_timeout::TimeoutExt as _;
 12use std::sync::atomic::Ordering::SeqCst;
 13use std::{
 14    collections::HashMap,
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22    time::Duration,
 23};
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(pub u32);
 27
 28impl fmt::Display for ConnectionId {
 29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 30        self.0.fmt(f)
 31    }
 32}
 33
 34#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 35pub struct PeerId(pub u32);
 36
 37impl fmt::Display for PeerId {
 38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 39        self.0.fmt(f)
 40    }
 41}
 42
 43pub struct Receipt<T> {
 44    pub sender_id: ConnectionId,
 45    pub message_id: u32,
 46    payload_type: PhantomData<T>,
 47}
 48
 49impl<T> Clone for Receipt<T> {
 50    fn clone(&self) -> Self {
 51        Self {
 52            sender_id: self.sender_id,
 53            message_id: self.message_id,
 54            payload_type: PhantomData,
 55        }
 56    }
 57}
 58
 59impl<T> Copy for Receipt<T> {}
 60
 61pub struct TypedEnvelope<T> {
 62    pub sender_id: ConnectionId,
 63    pub original_sender_id: Option<PeerId>,
 64    pub message_id: u32,
 65    pub payload: T,
 66}
 67
 68impl<T> TypedEnvelope<T> {
 69    pub fn original_sender_id(&self) -> Result<PeerId> {
 70        self.original_sender_id
 71            .ok_or_else(|| anyhow!("missing original_sender_id"))
 72    }
 73}
 74
 75impl<T: RequestMessage> TypedEnvelope<T> {
 76    pub fn receipt(&self) -> Receipt<T> {
 77        Receipt {
 78            sender_id: self.sender_id,
 79            message_id: self.message_id,
 80            payload_type: PhantomData,
 81        }
 82    }
 83}
 84
 85pub struct Peer {
 86    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 87    next_connection_id: AtomicU32,
 88}
 89
 90#[derive(Clone)]
 91pub struct ConnectionState {
 92    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
 93    next_message_id: Arc<AtomicU32>,
 94    response_channels:
 95        Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 96}
 97
 98const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 99
100impl Peer {
101    pub fn new() -> Arc<Self> {
102        Arc::new(Self {
103            connections: Default::default(),
104            next_connection_id: Default::default(),
105        })
106    }
107
108    pub async fn add_connection(
109        self: &Arc<Self>,
110        connection: Connection,
111    ) -> (
112        ConnectionId,
113        impl Future<Output = anyhow::Result<()>> + Send,
114        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
115    ) {
116        // For outgoing messages, use an unbounded channel so that application code
117        // can always send messages without yielding. For incoming messages, use a
118        // bounded channel so that other peers will receive backpressure if they send
119        // messages faster than this peer can process them.
120        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
122
123        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
124        let connection_state = ConnectionState {
125            outgoing_tx,
126            next_message_id: Default::default(),
127            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
128        };
129        let mut writer = MessageStream::new(connection.tx);
130        let mut reader = MessageStream::new(connection.rx);
131
132        let this = self.clone();
133        let response_channels = connection_state.response_channels.clone();
134        let handle_io = async move {
135            let result = 'outer: loop {
136                let read_message = reader.read_message().fuse();
137                futures::pin_mut!(read_message);
138                loop {
139                    futures::select_biased! {
140                        outgoing = outgoing_rx.next().fuse() => match outgoing {
141                            Some(outgoing) => {
142                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
143                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
144                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
145                                    _ => {}
146                                }
147                            }
148                            None => break 'outer Ok(()),
149                        },
150                        incoming = read_message => match incoming {
151                            Ok(incoming) => {
152                                if incoming_tx.send(incoming).await.is_err() {
153                                    break 'outer Ok(());
154                                }
155                                break;
156                            }
157                            Err(error) => {
158                                break 'outer Err(error).context("received invalid RPC message")
159                            }
160                        },
161                    }
162                }
163            };
164
165            response_channels.lock().take();
166            this.connections.write().remove(&connection_id);
167            result
168        };
169
170        let response_channels = connection_state.response_channels.clone();
171        self.connections
172            .write()
173            .insert(connection_id, connection_state);
174
175        let incoming_rx = incoming_rx.filter_map(move |incoming| {
176            let response_channels = response_channels.clone();
177            async move {
178                if let Some(responding_to) = incoming.responding_to {
179                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
180                    if let Some(mut tx) = channel {
181                        let mut requester_resumed = barrier::channel();
182                        tx.send((incoming, requester_resumed.0)).await.ok();
183                        requester_resumed.1.recv().await;
184                    } else {
185                        log::warn!("received RPC response to unknown request {}", responding_to);
186                    }
187
188                    None
189                } else {
190                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
191                        Some(envelope)
192                    } else {
193                        log::error!("unable to construct a typed envelope");
194                        None
195                    }
196                }
197            }
198        });
199        (connection_id, handle_io, incoming_rx.boxed())
200    }
201
202    pub fn disconnect(&self, connection_id: ConnectionId) {
203        self.connections.write().remove(&connection_id);
204    }
205
206    pub fn reset(&self) {
207        self.connections.write().clear();
208    }
209
210    pub fn request<T: RequestMessage>(
211        &self,
212        receiver_id: ConnectionId,
213        request: T,
214    ) -> impl Future<Output = Result<T::Response>> {
215        self.request_internal(None, receiver_id, request)
216    }
217
218    pub fn forward_request<T: RequestMessage>(
219        &self,
220        sender_id: ConnectionId,
221        receiver_id: ConnectionId,
222        request: T,
223    ) -> impl Future<Output = Result<T::Response>> {
224        self.request_internal(Some(sender_id), receiver_id, request)
225    }
226
227    pub fn request_internal<T: RequestMessage>(
228        &self,
229        original_sender_id: Option<ConnectionId>,
230        receiver_id: ConnectionId,
231        request: T,
232    ) -> impl Future<Output = Result<T::Response>> {
233        let (tx, mut rx) = mpsc::channel(1);
234        let send = self.connection_state(receiver_id).and_then(|connection| {
235            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
236            connection
237                .response_channels
238                .lock()
239                .as_mut()
240                .ok_or_else(|| anyhow!("connection was closed"))?
241                .insert(message_id, tx);
242            connection
243                .outgoing_tx
244                .unbounded_send(request.into_envelope(
245                    message_id,
246                    None,
247                    original_sender_id.map(|id| id.0),
248                ))
249                .map_err(|_| anyhow!("connection was closed"))?;
250            Ok(())
251        });
252        async move {
253            send?;
254            let (response, _barrier) = rx
255                .recv()
256                .await
257                .ok_or_else(|| anyhow!("connection was closed"))?;
258            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
259                Err(anyhow!("request failed").context(error.message.clone()))
260            } else {
261                T::Response::from_envelope(response)
262                    .ok_or_else(|| anyhow!("received response of the wrong type"))
263            }
264        }
265    }
266
267    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
268        let connection = self.connection_state(receiver_id)?;
269        let message_id = connection
270            .next_message_id
271            .fetch_add(1, atomic::Ordering::SeqCst);
272        connection
273            .outgoing_tx
274            .unbounded_send(message.into_envelope(message_id, None, None))?;
275        Ok(())
276    }
277
278    pub fn forward_send<T: EnvelopedMessage>(
279        &self,
280        sender_id: ConnectionId,
281        receiver_id: ConnectionId,
282        message: T,
283    ) -> Result<()> {
284        let connection = self.connection_state(receiver_id)?;
285        let message_id = connection
286            .next_message_id
287            .fetch_add(1, atomic::Ordering::SeqCst);
288        connection
289            .outgoing_tx
290            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
291        Ok(())
292    }
293
294    pub fn respond<T: RequestMessage>(
295        &self,
296        receipt: Receipt<T>,
297        response: T::Response,
298    ) -> Result<()> {
299        let connection = self.connection_state(receipt.sender_id)?;
300        let message_id = connection
301            .next_message_id
302            .fetch_add(1, atomic::Ordering::SeqCst);
303        connection
304            .outgoing_tx
305            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
306        Ok(())
307    }
308
309    pub fn respond_with_error<T: RequestMessage>(
310        &self,
311        receipt: Receipt<T>,
312        response: proto::Error,
313    ) -> Result<()> {
314        let connection = self.connection_state(receipt.sender_id)?;
315        let message_id = connection
316            .next_message_id
317            .fetch_add(1, atomic::Ordering::SeqCst);
318        connection
319            .outgoing_tx
320            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
321        Ok(())
322    }
323
324    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
325        let connections = self.connections.read();
326        let connection = connections
327            .get(&connection_id)
328            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
329        Ok(connection.clone())
330    }
331}
332
333#[cfg(test)]
334mod tests {
335    use super::*;
336    use crate::TypedEnvelope;
337    use async_tungstenite::tungstenite::Message as WebSocketMessage;
338    use gpui::TestAppContext;
339
340    #[gpui::test(iterations = 10)]
341    async fn test_request_response(cx: TestAppContext) {
342        let executor = cx.foreground();
343
344        // create 2 clients connected to 1 server
345        let server = Peer::new();
346        let client1 = Peer::new();
347        let client2 = Peer::new();
348
349        let (client1_to_server_conn, server_to_client_1_conn, _) =
350            Connection::in_memory(cx.background());
351        let (client1_conn_id, io_task1, client1_incoming) =
352            client1.add_connection(client1_to_server_conn).await;
353        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
354
355        let (client2_to_server_conn, server_to_client_2_conn, _) =
356            Connection::in_memory(cx.background());
357        let (client2_conn_id, io_task3, client2_incoming) =
358            client2.add_connection(client2_to_server_conn).await;
359        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
360
361        executor.spawn(io_task1).detach();
362        executor.spawn(io_task2).detach();
363        executor.spawn(io_task3).detach();
364        executor.spawn(io_task4).detach();
365        executor
366            .spawn(handle_messages(server_incoming1, server.clone()))
367            .detach();
368        executor
369            .spawn(handle_messages(client1_incoming, client1.clone()))
370            .detach();
371        executor
372            .spawn(handle_messages(server_incoming2, server.clone()))
373            .detach();
374        executor
375            .spawn(handle_messages(client2_incoming, client2.clone()))
376            .detach();
377
378        assert_eq!(
379            client1
380                .request(client1_conn_id, proto::Ping {},)
381                .await
382                .unwrap(),
383            proto::Ack {}
384        );
385
386        assert_eq!(
387            client2
388                .request(client2_conn_id, proto::Ping {},)
389                .await
390                .unwrap(),
391            proto::Ack {}
392        );
393
394        assert_eq!(
395            client1
396                .request(
397                    client1_conn_id,
398                    proto::OpenBuffer {
399                        project_id: 0,
400                        worktree_id: 1,
401                        path: "path/one".to_string(),
402                    },
403                )
404                .await
405                .unwrap(),
406            proto::OpenBufferResponse {
407                buffer: Some(proto::Buffer {
408                    variant: Some(proto::buffer::Variant::Id(0))
409                }),
410            }
411        );
412
413        assert_eq!(
414            client2
415                .request(
416                    client2_conn_id,
417                    proto::OpenBuffer {
418                        project_id: 0,
419                        worktree_id: 2,
420                        path: "path/two".to_string(),
421                    },
422                )
423                .await
424                .unwrap(),
425            proto::OpenBufferResponse {
426                buffer: Some(proto::Buffer {
427                    variant: Some(proto::buffer::Variant::Id(1))
428                })
429            }
430        );
431
432        client1.disconnect(client1_conn_id);
433        client2.disconnect(client1_conn_id);
434
435        async fn handle_messages(
436            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
437            peer: Arc<Peer>,
438        ) -> Result<()> {
439            while let Some(envelope) = messages.next().await {
440                let envelope = envelope.into_any();
441                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
442                    let receipt = envelope.receipt();
443                    peer.respond(receipt, proto::Ack {})?
444                } else if let Some(envelope) =
445                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
446                {
447                    let message = &envelope.payload;
448                    let receipt = envelope.receipt();
449                    let response = match message.path.as_str() {
450                        "path/one" => {
451                            assert_eq!(message.worktree_id, 1);
452                            proto::OpenBufferResponse {
453                                buffer: Some(proto::Buffer {
454                                    variant: Some(proto::buffer::Variant::Id(0)),
455                                }),
456                            }
457                        }
458                        "path/two" => {
459                            assert_eq!(message.worktree_id, 2);
460                            proto::OpenBufferResponse {
461                                buffer: Some(proto::Buffer {
462                                    variant: Some(proto::buffer::Variant::Id(1)),
463                                }),
464                            }
465                        }
466                        _ => {
467                            panic!("unexpected path {}", message.path);
468                        }
469                    };
470
471                    peer.respond(receipt, response)?
472                } else {
473                    panic!("unknown message type");
474                }
475            }
476
477            Ok(())
478        }
479    }
480
481    #[gpui::test(iterations = 10)]
482    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
483        let executor = cx.foreground();
484        let server = Peer::new();
485        let client = Peer::new();
486
487        let (client_to_server_conn, server_to_client_conn, _) =
488            Connection::in_memory(cx.background());
489        let (client_to_server_conn_id, io_task1, mut client_incoming) =
490            client.add_connection(client_to_server_conn).await;
491        let (server_to_client_conn_id, io_task2, mut server_incoming) =
492            server.add_connection(server_to_client_conn).await;
493
494        executor.spawn(io_task1).detach();
495        executor.spawn(io_task2).detach();
496
497        executor
498            .spawn(async move {
499                let request = server_incoming
500                    .next()
501                    .await
502                    .unwrap()
503                    .into_any()
504                    .downcast::<TypedEnvelope<proto::Ping>>()
505                    .unwrap();
506
507                server
508                    .send(
509                        server_to_client_conn_id,
510                        proto::Error {
511                            message: "message 1".to_string(),
512                        },
513                    )
514                    .unwrap();
515                server
516                    .send(
517                        server_to_client_conn_id,
518                        proto::Error {
519                            message: "message 2".to_string(),
520                        },
521                    )
522                    .unwrap();
523                server.respond(request.receipt(), proto::Ack {}).unwrap();
524
525                // Prevent the connection from being dropped
526                server_incoming.next().await;
527            })
528            .detach();
529
530        let events = Arc::new(Mutex::new(Vec::new()));
531
532        let response = client.request(client_to_server_conn_id, proto::Ping {});
533        let response_task = executor.spawn({
534            let events = events.clone();
535            async move {
536                response.await.unwrap();
537                events.lock().push("response".to_string());
538            }
539        });
540
541        executor
542            .spawn({
543                let events = events.clone();
544                async move {
545                    let incoming1 = client_incoming
546                        .next()
547                        .await
548                        .unwrap()
549                        .into_any()
550                        .downcast::<TypedEnvelope<proto::Error>>()
551                        .unwrap();
552                    events.lock().push(incoming1.payload.message);
553                    let incoming2 = client_incoming
554                        .next()
555                        .await
556                        .unwrap()
557                        .into_any()
558                        .downcast::<TypedEnvelope<proto::Error>>()
559                        .unwrap();
560                    events.lock().push(incoming2.payload.message);
561
562                    // Prevent the connection from being dropped
563                    client_incoming.next().await;
564                }
565            })
566            .detach();
567
568        response_task.await;
569        assert_eq!(
570            &*events.lock(),
571            &[
572                "message 1".to_string(),
573                "message 2".to_string(),
574                "response".to_string()
575            ]
576        );
577    }
578
579    #[gpui::test(iterations = 10)]
580    async fn test_disconnect(cx: TestAppContext) {
581        let executor = cx.foreground();
582
583        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
584
585        let client = Peer::new();
586        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
587
588        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
589        executor
590            .spawn(async move {
591                io_handler.await.ok();
592                io_ended_tx.send(()).await.unwrap();
593            })
594            .detach();
595
596        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
597        executor
598            .spawn(async move {
599                incoming.next().await;
600                messages_ended_tx.send(()).await.unwrap();
601            })
602            .detach();
603
604        client.disconnect(connection_id);
605
606        io_ended_rx.recv().await;
607        messages_ended_rx.recv().await;
608        assert!(server_conn
609            .send(WebSocketMessage::Binary(vec![]))
610            .await
611            .is_err());
612    }
613
614    #[gpui::test(iterations = 10)]
615    async fn test_io_error(cx: TestAppContext) {
616        let executor = cx.foreground();
617        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
618
619        let client = Peer::new();
620        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
621        executor.spawn(io_handler).detach();
622        executor
623            .spawn(async move { incoming.next().await })
624            .detach();
625
626        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
627        let _request = server_conn.rx.next().await.unwrap().unwrap();
628
629        drop(server_conn);
630        assert_eq!(
631            response.await.unwrap_err().to_string(),
632            "connection was closed"
633        );
634    }
635}