peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::stream::BoxStream;
  5use futures::{FutureExt as _, StreamExt};
  6use parking_lot::{Mutex, RwLock};
  7use postage::{
  8    mpsc,
  9    prelude::{Sink as _, Stream as _},
 10};
 11use smol_timeout::TimeoutExt as _;
 12use std::sync::atomic::Ordering::SeqCst;
 13use std::{
 14    collections::HashMap,
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22    time::Duration,
 23};
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(pub u32);
 27
 28impl fmt::Display for ConnectionId {
 29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 30        self.0.fmt(f)
 31    }
 32}
 33
 34#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 35pub struct PeerId(pub u32);
 36
 37impl fmt::Display for PeerId {
 38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 39        self.0.fmt(f)
 40    }
 41}
 42
 43pub struct Receipt<T> {
 44    pub sender_id: ConnectionId,
 45    pub message_id: u32,
 46    payload_type: PhantomData<T>,
 47}
 48
 49impl<T> Clone for Receipt<T> {
 50    fn clone(&self) -> Self {
 51        Self {
 52            sender_id: self.sender_id,
 53            message_id: self.message_id,
 54            payload_type: PhantomData,
 55        }
 56    }
 57}
 58
 59impl<T> Copy for Receipt<T> {}
 60
 61pub struct TypedEnvelope<T> {
 62    pub sender_id: ConnectionId,
 63    pub original_sender_id: Option<PeerId>,
 64    pub message_id: u32,
 65    pub payload: T,
 66}
 67
 68impl<T> TypedEnvelope<T> {
 69    pub fn original_sender_id(&self) -> Result<PeerId> {
 70        self.original_sender_id
 71            .ok_or_else(|| anyhow!("missing original_sender_id"))
 72    }
 73}
 74
 75impl<T: RequestMessage> TypedEnvelope<T> {
 76    pub fn receipt(&self) -> Receipt<T> {
 77        Receipt {
 78            sender_id: self.sender_id,
 79            message_id: self.message_id,
 80            payload_type: PhantomData,
 81        }
 82    }
 83}
 84
/// Multiplexes RPC messages over any number of connections.
///
/// The per-connection state lives in `connections`, keyed by the id assigned in
/// `add_connection`.
pub struct Peer {
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of `ConnectionId`s; bumped once per `add_connection`.
    next_connection_id: AtomicU32,
}

/// Handles to one live connection. Clones share the same queue, counter and response table.
#[derive(Clone)]
pub struct ConnectionState {
    // Queue of envelopes that the connection's I/O future writes out.
    outgoing_tx: mpsc::Sender<proto::Envelope>,
    // Per-connection counter used to number every outgoing envelope.
    next_message_id: Arc<AtomicU32>,
    // Senders for requests awaiting a response, keyed by the request's message id.
    // Becomes `None` when the I/O future exits, so pending and new requests fail.
    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
}

/// Maximum time allowed for writing one message before the I/O future fails with an error.
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 98
 99impl Peer {
100    pub fn new() -> Arc<Self> {
101        Arc::new(Self {
102            connections: Default::default(),
103            next_connection_id: Default::default(),
104        })
105    }
106
107    pub async fn add_connection(
108        self: &Arc<Self>,
109        connection: Connection,
110    ) -> (
111        ConnectionId,
112        impl Future<Output = anyhow::Result<()>> + Send,
113        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
114    ) {
115        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
116        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
117        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
118        let connection_state = ConnectionState {
119            outgoing_tx,
120            next_message_id: Default::default(),
121            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
122        };
123        let mut writer = MessageStream::new(connection.tx);
124        let mut reader = MessageStream::new(connection.rx);
125
126        let this = self.clone();
127        let response_channels = connection_state.response_channels.clone();
128        let handle_io = async move {
129            let result = 'outer: loop {
130                let read_message = reader.read_message().fuse();
131                futures::pin_mut!(read_message);
132                loop {
133                    futures::select_biased! {
134                        incoming = read_message => match incoming {
135                            Ok(incoming) => {
136                                if incoming_tx.send(incoming).await.is_err() {
137                                    break 'outer Ok(());
138                                }
139                                break;
140                            }
141                            Err(error) => {
142                                break 'outer Err(error).context("received invalid RPC message")
143                            }
144                        },
145                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
146                            Some(outgoing) => {
147                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
148                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
149                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
150                                    _ => {}
151                                }
152                            }
153                            None => break 'outer Ok(()),
154                        }
155                    }
156                }
157            };
158
159            response_channels.lock().take();
160            this.connections.write().remove(&connection_id);
161            result
162        };
163
164        let response_channels = connection_state.response_channels.clone();
165        self.connections
166            .write()
167            .insert(connection_id, connection_state);
168
169        let incoming_rx = incoming_rx.filter_map(move |incoming| {
170            let response_channels = response_channels.clone();
171            async move {
172                if let Some(responding_to) = incoming.responding_to {
173                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
174                    if let Some(mut tx) = channel {
175                        tx.send(incoming).await.ok();
176                    } else {
177                        log::warn!("received RPC response to unknown request {}", responding_to);
178                    }
179
180                    None
181                } else {
182                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
183                        Some(envelope)
184                    } else {
185                        log::error!("unable to construct a typed envelope");
186                        None
187                    }
188                }
189            }
190        });
191        (connection_id, handle_io, incoming_rx.boxed())
192    }
193
194    pub fn disconnect(&self, connection_id: ConnectionId) {
195        self.connections.write().remove(&connection_id);
196    }
197
198    pub fn reset(&self) {
199        self.connections.write().clear();
200    }
201
202    pub fn request<T: RequestMessage>(
203        self: &Arc<Self>,
204        receiver_id: ConnectionId,
205        request: T,
206    ) -> impl Future<Output = Result<T::Response>> {
207        self.request_internal(None, receiver_id, request)
208    }
209
210    pub fn forward_request<T: RequestMessage>(
211        self: &Arc<Self>,
212        sender_id: ConnectionId,
213        receiver_id: ConnectionId,
214        request: T,
215    ) -> impl Future<Output = Result<T::Response>> {
216        self.request_internal(Some(sender_id), receiver_id, request)
217    }
218
219    pub fn request_internal<T: RequestMessage>(
220        self: &Arc<Self>,
221        original_sender_id: Option<ConnectionId>,
222        receiver_id: ConnectionId,
223        request: T,
224    ) -> impl Future<Output = Result<T::Response>> {
225        let this = self.clone();
226        let (tx, mut rx) = mpsc::channel(1);
227        async move {
228            let mut connection = this.connection_state(receiver_id)?;
229            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
230            connection
231                .response_channels
232                .lock()
233                .as_mut()
234                .ok_or_else(|| anyhow!("connection was closed"))?
235                .insert(message_id, tx);
236            connection
237                .outgoing_tx
238                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
239                .await
240                .map_err(|_| anyhow!("connection was closed"))?;
241            let response = rx
242                .recv()
243                .await
244                .ok_or_else(|| anyhow!("connection was closed"))?;
245            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
246                Err(anyhow!("request failed").context(error.message.clone()))
247            } else {
248                T::Response::from_envelope(response)
249                    .ok_or_else(|| anyhow!("received response of the wrong type"))
250            }
251        }
252    }
253
254    pub fn send<T: EnvelopedMessage>(
255        self: &Arc<Self>,
256        receiver_id: ConnectionId,
257        message: T,
258    ) -> impl Future<Output = Result<()>> {
259        let this = self.clone();
260        async move {
261            let mut connection = this.connection_state(receiver_id)?;
262            let message_id = connection
263                .next_message_id
264                .fetch_add(1, atomic::Ordering::SeqCst);
265            connection
266                .outgoing_tx
267                .send(message.into_envelope(message_id, None, None))
268                .await?;
269            Ok(())
270        }
271    }
272
273    pub fn forward_send<T: EnvelopedMessage>(
274        self: &Arc<Self>,
275        sender_id: ConnectionId,
276        receiver_id: ConnectionId,
277        message: T,
278    ) -> impl Future<Output = Result<()>> {
279        let this = self.clone();
280        async move {
281            let mut connection = this.connection_state(receiver_id)?;
282            let message_id = connection
283                .next_message_id
284                .fetch_add(1, atomic::Ordering::SeqCst);
285            connection
286                .outgoing_tx
287                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
288                .await?;
289            Ok(())
290        }
291    }
292
293    pub fn respond<T: RequestMessage>(
294        self: &Arc<Self>,
295        receipt: Receipt<T>,
296        response: T::Response,
297    ) -> impl Future<Output = Result<()>> {
298        let this = self.clone();
299        async move {
300            let mut connection = this.connection_state(receipt.sender_id)?;
301            let message_id = connection
302                .next_message_id
303                .fetch_add(1, atomic::Ordering::SeqCst);
304            connection
305                .outgoing_tx
306                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
307                .await?;
308            Ok(())
309        }
310    }
311
312    pub fn respond_with_error<T: RequestMessage>(
313        self: &Arc<Self>,
314        receipt: Receipt<T>,
315        response: proto::Error,
316    ) -> impl Future<Output = Result<()>> {
317        let this = self.clone();
318        async move {
319            let mut connection = this.connection_state(receipt.sender_id)?;
320            let message_id = connection
321                .next_message_id
322                .fetch_add(1, atomic::Ordering::SeqCst);
323            connection
324                .outgoing_tx
325                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
326                .await?;
327            Ok(())
328        }
329    }
330
331    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
332        let connections = self.connections.read();
333        let connection = connections
334            .get(&connection_id)
335            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
336        Ok(connection.clone())
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients talk to one server; each request must get the matching response.
    #[gpui::test(iterations = 10)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(
                    client1_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 1,
                        path: "path/one".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 101,
                    visible_text: "path/one content".to_string(),
                    ..Default::default()
                }),
            }
        );

        assert_eq!(
            client2
                .request(
                    client2_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 2,
                        path: "path/two".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 102,
                    visible_text: "path/two content".to_string(),
                    ..Default::default()
                }),
            }
        );

        // Each client must disconnect its own connection id. The ids happen to be equal
        // numerically (both peers start counting at 0), which hid a copy-paste mistake here.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and serves two canned `OpenBuffer` paths.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {}).await?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {
                    let message = &envelope.payload;
                    let receipt = envelope.receipt();
                    let response = match message.path.as_str() {
                        "path/one" => {
                            assert_eq!(message.worktree_id, 1);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    id: 101,
                                    visible_text: "path/one content".to_string(),
                                    ..Default::default()
                                }),
                            }
                        }
                        "path/two" => {
                            assert_eq!(message.worktree_id, 2);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    id: 102,
                                    visible_text: "path/two content".to_string(),
                                    ..Default::default()
                                }),
                            }
                        }
                        _ => {
                            panic!("unexpected path {}", message.path);
                        }
                    };

                    peer.respond(receipt, response).await?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response must be observed on the incoming stream before
    /// the request future resolves.
    #[gpui::test(iterations = 10)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) = Connection::in_memory();
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .await
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .await
                    .unwrap();
                server
                    .respond(request.receipt(), proto::Ack {})
                    .await
                    .unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Disconnecting ends both the I/O future and the incoming stream and closes the socket.
    #[gpui::test(iterations = 10)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory();

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the remote side drops.
    #[gpui::test(iterations = 10)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory();

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}