1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::stream::BoxStream;
5use futures::{FutureExt as _, StreamExt};
6use parking_lot::{Mutex, RwLock};
7use postage::{
8 mpsc,
9 prelude::{Sink as _, Stream as _},
10};
11use smol_timeout::TimeoutExt as _;
12use std::sync::atomic::Ordering::SeqCst;
13use std::{
14 collections::HashMap,
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24
/// Identifies one connection registered with a `Peer`, by its raw numeric id.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Prints the raw id, honoring any width/fill/alignment flags of `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConnectionId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
33
/// Identifies a peer by its raw numeric id.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Delegates to the `Display` impl of the inner `u32`, so formatter flags apply.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <u32 as fmt::Display>::fmt(&self.0, f)
    }
}
42
43pub struct Receipt<T> {
44 pub sender_id: ConnectionId,
45 pub message_id: u32,
46 payload_type: PhantomData<T>,
47}
48
49impl<T> Clone for Receipt<T> {
50 fn clone(&self) -> Self {
51 Self {
52 sender_id: self.sender_id,
53 message_id: self.message_id,
54 payload_type: PhantomData,
55 }
56 }
57}
58
59impl<T> Copy for Receipt<T> {}
60
61pub struct TypedEnvelope<T> {
62 pub sender_id: ConnectionId,
63 pub original_sender_id: Option<PeerId>,
64 pub message_id: u32,
65 pub payload: T,
66}
67
68impl<T> TypedEnvelope<T> {
69 pub fn original_sender_id(&self) -> Result<PeerId> {
70 self.original_sender_id
71 .ok_or_else(|| anyhow!("missing original_sender_id"))
72 }
73}
74
75impl<T: RequestMessage> TypedEnvelope<T> {
76 pub fn receipt(&self) -> Receipt<T> {
77 Receipt {
78 sender_id: self.sender_id,
79 message_id: self.message_id,
80 payload_type: PhantomData,
81 }
82 }
83}
84
85pub struct Peer {
86 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
87 next_connection_id: AtomicU32,
88}
89
90#[derive(Clone)]
91pub struct ConnectionState {
92 outgoing_tx: mpsc::Sender<proto::Envelope>,
93 next_message_id: Arc<AtomicU32>,
94 response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
95}
96
/// Longest a single outgoing message write may take before the connection's
/// I/O future gives up with "timed out writing RPC message".
const WRITE_TIMEOUT: Duration = Duration::from_millis(10_000);
98
99impl Peer {
100 pub fn new() -> Arc<Self> {
101 Arc::new(Self {
102 connections: Default::default(),
103 next_connection_id: Default::default(),
104 })
105 }
106
107 pub async fn add_connection(
108 self: &Arc<Self>,
109 connection: Connection,
110 ) -> (
111 ConnectionId,
112 impl Future<Output = anyhow::Result<()>> + Send,
113 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
114 ) {
115 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
116 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
117 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
118 let connection_state = ConnectionState {
119 outgoing_tx,
120 next_message_id: Default::default(),
121 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
122 };
123 let mut writer = MessageStream::new(connection.tx);
124 let mut reader = MessageStream::new(connection.rx);
125
126 let this = self.clone();
127 let response_channels = connection_state.response_channels.clone();
128 let handle_io = async move {
129 let result = 'outer: loop {
130 let read_message = reader.read_message().fuse();
131 futures::pin_mut!(read_message);
132 loop {
133 futures::select_biased! {
134 incoming = read_message => match incoming {
135 Ok(incoming) => {
136 if incoming_tx.send(incoming).await.is_err() {
137 break 'outer Ok(());
138 }
139 break;
140 }
141 Err(error) => {
142 break 'outer Err(error).context("received invalid RPC message")
143 }
144 },
145 outgoing = outgoing_rx.recv().fuse() => match outgoing {
146 Some(outgoing) => {
147 match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
148 None => break 'outer Err(anyhow!("timed out writing RPC message")),
149 Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
150 _ => {}
151 }
152 }
153 None => break 'outer Ok(()),
154 }
155 }
156 }
157 };
158
159 response_channels.lock().take();
160 this.connections.write().remove(&connection_id);
161 result
162 };
163
164 let response_channels = connection_state.response_channels.clone();
165 self.connections
166 .write()
167 .insert(connection_id, connection_state);
168
169 let incoming_rx = incoming_rx.filter_map(move |incoming| {
170 let response_channels = response_channels.clone();
171 async move {
172 if let Some(responding_to) = incoming.responding_to {
173 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
174 if let Some(mut tx) = channel {
175 tx.send(incoming).await.ok();
176 } else {
177 log::warn!("received RPC response to unknown request {}", responding_to);
178 }
179
180 None
181 } else {
182 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
183 Some(envelope)
184 } else {
185 log::error!("unable to construct a typed envelope");
186 None
187 }
188 }
189 }
190 });
191 (connection_id, handle_io, incoming_rx.boxed())
192 }
193
194 pub fn disconnect(&self, connection_id: ConnectionId) {
195 self.connections.write().remove(&connection_id);
196 }
197
198 pub fn reset(&self) {
199 self.connections.write().clear();
200 }
201
202 pub fn request<T: RequestMessage>(
203 self: &Arc<Self>,
204 receiver_id: ConnectionId,
205 request: T,
206 ) -> impl Future<Output = Result<T::Response>> {
207 self.request_internal(None, receiver_id, request)
208 }
209
210 pub fn forward_request<T: RequestMessage>(
211 self: &Arc<Self>,
212 sender_id: ConnectionId,
213 receiver_id: ConnectionId,
214 request: T,
215 ) -> impl Future<Output = Result<T::Response>> {
216 self.request_internal(Some(sender_id), receiver_id, request)
217 }
218
219 pub fn request_internal<T: RequestMessage>(
220 self: &Arc<Self>,
221 original_sender_id: Option<ConnectionId>,
222 receiver_id: ConnectionId,
223 request: T,
224 ) -> impl Future<Output = Result<T::Response>> {
225 let this = self.clone();
226 let (tx, mut rx) = mpsc::channel(1);
227 async move {
228 let mut connection = this.connection_state(receiver_id)?;
229 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
230 connection
231 .response_channels
232 .lock()
233 .as_mut()
234 .ok_or_else(|| anyhow!("connection was closed"))?
235 .insert(message_id, tx);
236 connection
237 .outgoing_tx
238 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
239 .await
240 .map_err(|_| anyhow!("connection was closed"))?;
241 let response = rx
242 .recv()
243 .await
244 .ok_or_else(|| anyhow!("connection was closed"))?;
245 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
246 Err(anyhow!("request failed").context(error.message.clone()))
247 } else {
248 T::Response::from_envelope(response)
249 .ok_or_else(|| anyhow!("received response of the wrong type"))
250 }
251 }
252 }
253
254 pub fn send<T: EnvelopedMessage>(
255 self: &Arc<Self>,
256 receiver_id: ConnectionId,
257 message: T,
258 ) -> impl Future<Output = Result<()>> {
259 let this = self.clone();
260 async move {
261 let mut connection = this.connection_state(receiver_id)?;
262 let message_id = connection
263 .next_message_id
264 .fetch_add(1, atomic::Ordering::SeqCst);
265 connection
266 .outgoing_tx
267 .send(message.into_envelope(message_id, None, None))
268 .await?;
269 Ok(())
270 }
271 }
272
273 pub fn forward_send<T: EnvelopedMessage>(
274 self: &Arc<Self>,
275 sender_id: ConnectionId,
276 receiver_id: ConnectionId,
277 message: T,
278 ) -> impl Future<Output = Result<()>> {
279 let this = self.clone();
280 async move {
281 let mut connection = this.connection_state(receiver_id)?;
282 let message_id = connection
283 .next_message_id
284 .fetch_add(1, atomic::Ordering::SeqCst);
285 connection
286 .outgoing_tx
287 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
288 .await?;
289 Ok(())
290 }
291 }
292
293 pub fn respond<T: RequestMessage>(
294 self: &Arc<Self>,
295 receipt: Receipt<T>,
296 response: T::Response,
297 ) -> impl Future<Output = Result<()>> {
298 let this = self.clone();
299 async move {
300 let mut connection = this.connection_state(receipt.sender_id)?;
301 let message_id = connection
302 .next_message_id
303 .fetch_add(1, atomic::Ordering::SeqCst);
304 connection
305 .outgoing_tx
306 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
307 .await?;
308 Ok(())
309 }
310 }
311
312 pub fn respond_with_error<T: RequestMessage>(
313 self: &Arc<Self>,
314 receipt: Receipt<T>,
315 response: proto::Error,
316 ) -> impl Future<Output = Result<()>> {
317 let this = self.clone();
318 async move {
319 let mut connection = this.connection_state(receipt.sender_id)?;
320 let message_id = connection
321 .next_message_id
322 .fetch_add(1, atomic::Ordering::SeqCst);
323 connection
324 .outgoing_tx
325 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
326 .await?;
327 Ok(())
328 }
329 }
330
331 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
332 let connections = self.connections.read();
333 let connection = connections
334 .get(&connection_id)
335 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
336 Ok(connection.clone())
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients talk to one server; each request gets the matching response.
    #[gpui::test(iterations = 10)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(
                    client1_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 1,
                        path: "path/one".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 101,
                    visible_text: "path/one content".to_string(),
                    ..Default::default()
                }),
            }
        );

        assert_eq!(
            client2
                .request(
                    client2_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 2,
                        path: "path/two".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 102,
                    visible_text: "path/two content".to_string(),
                    ..Default::default()
                }),
            }
        );

        // Each client disconnects its own connection. The ids happen to be
        // equal (every `Peer` numbers its connections from 0), which is why
        // passing client1's id to client2 used to work by accident.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and `OpenBuffer` with a canned buffer.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {}).await?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {
                    let message = &envelope.payload;
                    let receipt = envelope.receipt();
                    let response = match message.path.as_str() {
                        "path/one" => {
                            assert_eq!(message.worktree_id, 1);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    id: 101,
                                    visible_text: "path/one content".to_string(),
                                    ..Default::default()
                                }),
                            }
                        }
                        "path/two" => {
                            assert_eq!(message.worktree_id, 2);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    id: 102,
                                    visible_text: "path/two content".to_string(),
                                    ..Default::default()
                                }),
                            }
                        }
                        _ => {
                            panic!("unexpected path {}", message.path);
                        }
                    };

                    peer.respond(receipt, response).await?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response are observed before that response resolves.
    #[gpui::test(iterations = 10)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) = Connection::in_memory();
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .await
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .await
                    .unwrap();
                server
                    .respond(request.receipt(), proto::Ack {})
                    .await
                    .unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Disconnecting ends both the I/O future and the incoming stream and closes the socket.
    #[gpui::test(iterations = 10)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory();

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the remote end goes away.
    #[gpui::test(iterations = 10)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory();

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}