1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{FutureExt, StreamExt};
6use postage::{
7 mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use std::{
11 any::Any,
12 collections::HashMap,
13 fmt,
14 future::Future,
15 marker::PhantomData,
16 sync::{
17 atomic::{self, AtomicU32},
18 Arc,
19 },
20};
21
/// Identifier of one connection registered on a `Peer`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Formats as the bare numeric id, forwarding the caller's formatter
    /// (so width/alignment flags apply to the number).
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = &self.0;
        fmt::Display::fmt(raw, formatter)
    }
}
30
/// Identifier of a remote peer, carried in envelopes as the original sender.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Formats as the bare numeric id, forwarding the caller's formatter
    /// (so width/alignment flags apply to the number).
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let raw = &self.0;
        fmt::Display::fmt(raw, formatter)
    }
}
39
40pub struct Receipt<T> {
41 pub sender_id: ConnectionId,
42 pub message_id: u32,
43 payload_type: PhantomData<T>,
44}
45
46impl<T> Clone for Receipt<T> {
47 fn clone(&self) -> Self {
48 Self {
49 sender_id: self.sender_id,
50 message_id: self.message_id,
51 payload_type: PhantomData,
52 }
53 }
54}
55
56impl<T> Copy for Receipt<T> {}
57
58pub struct TypedEnvelope<T> {
59 pub sender_id: ConnectionId,
60 pub original_sender_id: Option<PeerId>,
61 pub message_id: u32,
62 pub payload: T,
63}
64
65impl<T> TypedEnvelope<T> {
66 pub fn original_sender_id(&self) -> Result<PeerId> {
67 self.original_sender_id
68 .ok_or_else(|| anyhow!("missing original_sender_id"))
69 }
70}
71
72impl<T: RequestMessage> TypedEnvelope<T> {
73 pub fn receipt(&self) -> Receipt<T> {
74 Receipt {
75 sender_id: self.sender_id,
76 message_id: self.message_id,
77 payload_type: PhantomData,
78 }
79 }
80}
81
82pub struct Peer {
83 connections: RwLock<HashMap<ConnectionId, Connection>>,
84 next_connection_id: AtomicU32,
85}
86
87#[derive(Clone)]
88struct Connection {
89 outgoing_tx: mpsc::Sender<proto::Envelope>,
90 next_message_id: Arc<AtomicU32>,
91 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
92}
93
94impl Peer {
95 pub fn new() -> Arc<Self> {
96 Arc::new(Self {
97 connections: Default::default(),
98 next_connection_id: Default::default(),
99 })
100 }
101
102 pub async fn add_connection<Conn>(
103 self: &Arc<Self>,
104 conn: Conn,
105 ) -> (
106 ConnectionId,
107 impl Future<Output = anyhow::Result<()>> + Send,
108 mpsc::Receiver<Box<dyn Any + Sync + Send>>,
109 )
110 where
111 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
112 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
113 + Send
114 + Unpin,
115 {
116 let (tx, rx) = conn.split();
117 let connection_id = ConnectionId(
118 self.next_connection_id
119 .fetch_add(1, atomic::Ordering::SeqCst),
120 );
121 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
122 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
123 let connection = Connection {
124 outgoing_tx,
125 next_message_id: Default::default(),
126 response_channels: Default::default(),
127 };
128 let mut writer = MessageStream::new(tx);
129 let mut reader = MessageStream::new(rx);
130
131 let response_channels = connection.response_channels.clone();
132 let handle_io = async move {
133 loop {
134 let read_message = reader.read_message().fuse();
135 futures::pin_mut!(read_message);
136 loop {
137 futures::select_biased! {
138 incoming = read_message => match incoming {
139 Ok(incoming) => {
140 if let Some(responding_to) = incoming.responding_to {
141 let channel = response_channels.lock().await.remove(&responding_to);
142 if let Some(mut tx) = channel {
143 tx.send(incoming).await.ok();
144 } else {
145 log::warn!("received RPC response to unknown request {}", responding_to);
146 }
147 } else {
148 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
149 if incoming_tx.send(envelope).await.is_err() {
150 response_channels.lock().await.clear();
151 return Ok(())
152 }
153 } else {
154 log::error!("unable to construct a typed envelope");
155 }
156 }
157
158 break;
159 }
160 Err(error) => {
161 response_channels.lock().await.clear();
162 Err(error).context("received invalid RPC message")?;
163 }
164 },
165 outgoing = outgoing_rx.recv().fuse() => match outgoing {
166 Some(outgoing) => {
167 if let Err(result) = writer.write_message(&outgoing).await {
168 response_channels.lock().await.clear();
169 Err(result).context("failed to write RPC message")?;
170 }
171 }
172 None => {
173 response_channels.lock().await.clear();
174 return Ok(())
175 }
176 }
177 }
178 }
179 }
180 };
181
182 self.connections
183 .write()
184 .await
185 .insert(connection_id, connection);
186
187 (connection_id, handle_io, incoming_rx)
188 }
189
190 pub async fn disconnect(&self, connection_id: ConnectionId) {
191 self.connections.write().await.remove(&connection_id);
192 }
193
194 pub async fn reset(&self) {
195 self.connections.write().await.clear();
196 }
197
198 pub fn request<T: RequestMessage>(
199 self: &Arc<Self>,
200 receiver_id: ConnectionId,
201 request: T,
202 ) -> impl Future<Output = Result<T::Response>> {
203 self.request_internal(None, receiver_id, request)
204 }
205
206 pub fn forward_request<T: RequestMessage>(
207 self: &Arc<Self>,
208 sender_id: ConnectionId,
209 receiver_id: ConnectionId,
210 request: T,
211 ) -> impl Future<Output = Result<T::Response>> {
212 self.request_internal(Some(sender_id), receiver_id, request)
213 }
214
215 pub fn request_internal<T: RequestMessage>(
216 self: &Arc<Self>,
217 original_sender_id: Option<ConnectionId>,
218 receiver_id: ConnectionId,
219 request: T,
220 ) -> impl Future<Output = Result<T::Response>> {
221 let this = self.clone();
222 let (tx, mut rx) = mpsc::channel(1);
223 async move {
224 let mut connection = this.connection(receiver_id).await?;
225 let message_id = connection
226 .next_message_id
227 .fetch_add(1, atomic::Ordering::SeqCst);
228 connection
229 .response_channels
230 .lock()
231 .await
232 .insert(message_id, tx);
233 connection
234 .outgoing_tx
235 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
236 .await
237 .map_err(|_| anyhow!("connection was closed"))?;
238 let response = rx
239 .recv()
240 .await
241 .ok_or_else(|| anyhow!("connection was closed"))?;
242 T::Response::from_envelope(response)
243 .ok_or_else(|| anyhow!("received response of the wrong type"))
244 }
245 }
246
247 pub fn send<T: EnvelopedMessage>(
248 self: &Arc<Self>,
249 receiver_id: ConnectionId,
250 message: T,
251 ) -> impl Future<Output = Result<()>> {
252 let this = self.clone();
253 async move {
254 let mut connection = this.connection(receiver_id).await?;
255 let message_id = connection
256 .next_message_id
257 .fetch_add(1, atomic::Ordering::SeqCst);
258 connection
259 .outgoing_tx
260 .send(message.into_envelope(message_id, None, None))
261 .await?;
262 Ok(())
263 }
264 }
265
266 pub fn forward_send<T: EnvelopedMessage>(
267 self: &Arc<Self>,
268 sender_id: ConnectionId,
269 receiver_id: ConnectionId,
270 message: T,
271 ) -> impl Future<Output = Result<()>> {
272 let this = self.clone();
273 async move {
274 let mut connection = this.connection(receiver_id).await?;
275 let message_id = connection
276 .next_message_id
277 .fetch_add(1, atomic::Ordering::SeqCst);
278 connection
279 .outgoing_tx
280 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
281 .await?;
282 Ok(())
283 }
284 }
285
286 pub fn respond<T: RequestMessage>(
287 self: &Arc<Self>,
288 receipt: Receipt<T>,
289 response: T::Response,
290 ) -> impl Future<Output = Result<()>> {
291 let this = self.clone();
292 async move {
293 let mut connection = this.connection(receipt.sender_id).await?;
294 let message_id = connection
295 .next_message_id
296 .fetch_add(1, atomic::Ordering::SeqCst);
297 connection
298 .outgoing_tx
299 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
300 .await?;
301 Ok(())
302 }
303 }
304
305 fn connection(
306 self: &Arc<Self>,
307 connection_id: ConnectionId,
308 ) -> impl Future<Output = Result<Connection>> {
309 let this = self.clone();
310 async move {
311 let connections = this.connections.read().await;
312 let connection = connections
313 .get(&connection_id)
314 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
315 Ok(connection.clone())
316 }
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};

    /// Two clients talk to one server over in-memory channels; each request must be
    /// answered with the matching response.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping { id: 1 })
                    .await
                    .unwrap(),
                proto::Pong { id: 1 }
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping { id: 2 })
                    .await
                    .unwrap(),
                proto::Pong { id: 2 }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            client1.disconnect(client1_conn_id).await;
            // Each client disconnects its own connection.
            client2.disconnect(client2_conn_id).await;

            /// Server-side handler: answers `Ping` with `Pong` and `OpenBuffer` with a
            /// canned buffer; panics on anything else.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(
                            receipt,
                            proto::Pong {
                                id: envelope.payload.id,
                            },
                        )
                        .await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting must end both the I/O future and the incoming-message stream, and
    /// drop the underlying connection so the other side can no longer send.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request on a connection whose other end is already gone fails with
    /// "connection was closed".
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}