1use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{FutureExt, StreamExt};
6use postage::{
7 mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use std::{
11 collections::HashMap,
12 fmt,
13 future::Future,
14 marker::PhantomData,
15 sync::{
16 atomic::{self, AtomicU32},
17 Arc,
18 },
19};
20
/// Identifies one connection of a [`Peer`].
///
/// Ids come from a counter owned by each `Peer`, so the same value may name different
/// connections on different peers.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Writes the raw numeric id, honouring any width/fill/alignment flags set on `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConnectionId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
29
/// Numeric identifier of a peer, as carried in [`TypedEnvelope::original_sender_id`].
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Writes the raw numeric id, honouring any width/fill/alignment flags set on `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PeerId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
38
39pub struct Receipt<T> {
40 pub sender_id: ConnectionId,
41 pub message_id: u32,
42 payload_type: PhantomData<T>,
43}
44
45impl<T> Clone for Receipt<T> {
46 fn clone(&self) -> Self {
47 Self {
48 sender_id: self.sender_id,
49 message_id: self.message_id,
50 payload_type: PhantomData,
51 }
52 }
53}
54
55impl<T> Copy for Receipt<T> {}
56
57pub struct TypedEnvelope<T> {
58 pub sender_id: ConnectionId,
59 pub original_sender_id: Option<PeerId>,
60 pub message_id: u32,
61 pub payload: T,
62}
63
64impl<T> TypedEnvelope<T> {
65 pub fn original_sender_id(&self) -> Result<PeerId> {
66 self.original_sender_id
67 .ok_or_else(|| anyhow!("missing original_sender_id"))
68 }
69}
70
71impl<T: RequestMessage> TypedEnvelope<T> {
72 pub fn receipt(&self) -> Receipt<T> {
73 Receipt {
74 sender_id: self.sender_id,
75 message_id: self.message_id,
76 payload_type: PhantomData,
77 }
78 }
79}
80
81pub struct Peer {
82 connections: RwLock<HashMap<ConnectionId, Connection>>,
83 next_connection_id: AtomicU32,
84}
85
86#[derive(Clone)]
87struct Connection {
88 outgoing_tx: mpsc::Sender<proto::Envelope>,
89 next_message_id: Arc<AtomicU32>,
90 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
91}
92
93impl Peer {
94 pub fn new() -> Arc<Self> {
95 Arc::new(Self {
96 connections: Default::default(),
97 next_connection_id: Default::default(),
98 })
99 }
100
101 pub async fn add_connection<Conn>(
102 self: &Arc<Self>,
103 conn: Conn,
104 ) -> (
105 ConnectionId,
106 impl Future<Output = anyhow::Result<()>> + Send,
107 mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
108 )
109 where
110 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
111 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
112 + Send
113 + Unpin,
114 {
115 let (tx, rx) = conn.split();
116 let connection_id = ConnectionId(
117 self.next_connection_id
118 .fetch_add(1, atomic::Ordering::SeqCst),
119 );
120 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
122 let connection = Connection {
123 outgoing_tx,
124 next_message_id: Default::default(),
125 response_channels: Default::default(),
126 };
127 let mut writer = MessageStream::new(tx);
128 let mut reader = MessageStream::new(rx);
129
130 let response_channels = connection.response_channels.clone();
131 let handle_io = async move {
132 loop {
133 let read_message = reader.read_message().fuse();
134 futures::pin_mut!(read_message);
135 loop {
136 futures::select_biased! {
137 incoming = read_message => match incoming {
138 Ok(incoming) => {
139 if let Some(responding_to) = incoming.responding_to {
140 let channel = response_channels.lock().await.remove(&responding_to);
141 if let Some(mut tx) = channel {
142 tx.send(incoming).await.ok();
143 } else {
144 log::warn!("received RPC response to unknown request {}", responding_to);
145 }
146 } else {
147 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
148 if incoming_tx.send(envelope).await.is_err() {
149 response_channels.lock().await.clear();
150 return Ok(())
151 }
152 } else {
153 log::error!("unable to construct a typed envelope");
154 }
155 }
156
157 break;
158 }
159 Err(error) => {
160 response_channels.lock().await.clear();
161 Err(error).context("received invalid RPC message")?;
162 }
163 },
164 outgoing = outgoing_rx.recv().fuse() => match outgoing {
165 Some(outgoing) => {
166 if let Err(result) = writer.write_message(&outgoing).await {
167 response_channels.lock().await.clear();
168 Err(result).context("failed to write RPC message")?;
169 }
170 }
171 None => {
172 response_channels.lock().await.clear();
173 return Ok(())
174 }
175 }
176 }
177 }
178 }
179 };
180
181 self.connections
182 .write()
183 .await
184 .insert(connection_id, connection);
185
186 (connection_id, handle_io, incoming_rx)
187 }
188
189 pub async fn disconnect(&self, connection_id: ConnectionId) {
190 self.connections.write().await.remove(&connection_id);
191 }
192
193 pub async fn reset(&self) {
194 self.connections.write().await.clear();
195 }
196
197 pub fn request<T: RequestMessage>(
198 self: &Arc<Self>,
199 receiver_id: ConnectionId,
200 request: T,
201 ) -> impl Future<Output = Result<T::Response>> {
202 self.request_internal(None, receiver_id, request)
203 }
204
205 pub fn forward_request<T: RequestMessage>(
206 self: &Arc<Self>,
207 sender_id: ConnectionId,
208 receiver_id: ConnectionId,
209 request: T,
210 ) -> impl Future<Output = Result<T::Response>> {
211 self.request_internal(Some(sender_id), receiver_id, request)
212 }
213
214 pub fn request_internal<T: RequestMessage>(
215 self: &Arc<Self>,
216 original_sender_id: Option<ConnectionId>,
217 receiver_id: ConnectionId,
218 request: T,
219 ) -> impl Future<Output = Result<T::Response>> {
220 let this = self.clone();
221 let (tx, mut rx) = mpsc::channel(1);
222 async move {
223 let mut connection = this.connection(receiver_id).await?;
224 let message_id = connection
225 .next_message_id
226 .fetch_add(1, atomic::Ordering::SeqCst);
227 connection
228 .response_channels
229 .lock()
230 .await
231 .insert(message_id, tx);
232 connection
233 .outgoing_tx
234 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
235 .await
236 .map_err(|_| anyhow!("connection was closed"))?;
237 let response = rx
238 .recv()
239 .await
240 .ok_or_else(|| anyhow!("connection was closed"))?;
241 T::Response::from_envelope(response)
242 .ok_or_else(|| anyhow!("received response of the wrong type"))
243 }
244 }
245
246 pub fn send<T: EnvelopedMessage>(
247 self: &Arc<Self>,
248 receiver_id: ConnectionId,
249 message: T,
250 ) -> impl Future<Output = Result<()>> {
251 let this = self.clone();
252 async move {
253 let mut connection = this.connection(receiver_id).await?;
254 let message_id = connection
255 .next_message_id
256 .fetch_add(1, atomic::Ordering::SeqCst);
257 connection
258 .outgoing_tx
259 .send(message.into_envelope(message_id, None, None))
260 .await?;
261 Ok(())
262 }
263 }
264
265 pub fn forward_send<T: EnvelopedMessage>(
266 self: &Arc<Self>,
267 sender_id: ConnectionId,
268 receiver_id: ConnectionId,
269 message: T,
270 ) -> impl Future<Output = Result<()>> {
271 let this = self.clone();
272 async move {
273 let mut connection = this.connection(receiver_id).await?;
274 let message_id = connection
275 .next_message_id
276 .fetch_add(1, atomic::Ordering::SeqCst);
277 connection
278 .outgoing_tx
279 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
280 .await?;
281 Ok(())
282 }
283 }
284
285 pub fn respond<T: RequestMessage>(
286 self: &Arc<Self>,
287 receipt: Receipt<T>,
288 response: T::Response,
289 ) -> impl Future<Output = Result<()>> {
290 let this = self.clone();
291 async move {
292 let mut connection = this.connection(receipt.sender_id).await?;
293 let message_id = connection
294 .next_message_id
295 .fetch_add(1, atomic::Ordering::SeqCst);
296 connection
297 .outgoing_tx
298 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
299 .await?;
300 Ok(())
301 }
302 }
303
304 fn connection(
305 self: &Arc<Self>,
306 connection_id: ConnectionId,
307 ) -> impl Future<Output = Result<Connection>> {
308 let this = self.clone();
309 async move {
310 let connections = this.connections.read().await;
311 let connection = connections
312 .get(&connection_id)
313 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
314 Ok(connection.clone())
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping { id: 1 })
                    .await
                    .unwrap(),
                proto::Pong { id: 1 }
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping { id: 2 })
                    .await
                    .unwrap(),
                proto::Pong { id: 2 }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client disconnects its own connection. Connection ids are allocated per
            // peer, so both ids are 0 here; that coincidence previously hid a call that passed
            // client1's id to client2.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Server-side handler: answers `Ping` with a matching `Pong` and `OpenBuffer`
            /// with canned contents for "path/one" and "path/two"; panics on anything else.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(
                            receipt,
                            proto::Pong {
                                id: envelope.payload.id,
                            },
                        )
                        .await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            // Disconnecting must end the I/O future and the incoming stream, and close the
            // underlying transport.
            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(connection_id, proto::Ping { id: 42 })
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}