1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use async_lock::{Mutex, RwLock};
5use futures::FutureExt as _;
6use postage::{
7 mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use std::{
11 collections::HashMap,
12 fmt,
13 future::Future,
14 marker::PhantomData,
15 sync::{
16 atomic::{self, AtomicU32},
17 Arc,
18 },
19};
20
/// Identifies a connection registered with a [`Peer`]; ids are handed out
/// sequentially per peer by [`Peer::add_connection`].
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Displays the raw numeric id, forwarding any width/fill/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConnectionId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
29
/// Identifies the original sender of a forwarded message, as carried in
/// `TypedEnvelope::original_sender_id`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Displays the raw numeric id, forwarding any width/fill/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <u32 as fmt::Display>::fmt(&self.0, f)
    }
}
38
39pub struct Receipt<T> {
40 pub sender_id: ConnectionId,
41 pub message_id: u32,
42 payload_type: PhantomData<T>,
43}
44
45impl<T> Clone for Receipt<T> {
46 fn clone(&self) -> Self {
47 Self {
48 sender_id: self.sender_id,
49 message_id: self.message_id,
50 payload_type: PhantomData,
51 }
52 }
53}
54
55impl<T> Copy for Receipt<T> {}
56
57pub struct TypedEnvelope<T> {
58 pub sender_id: ConnectionId,
59 pub original_sender_id: Option<PeerId>,
60 pub message_id: u32,
61 pub payload: T,
62}
63
64impl<T> TypedEnvelope<T> {
65 pub fn original_sender_id(&self) -> Result<PeerId> {
66 self.original_sender_id
67 .ok_or_else(|| anyhow!("missing original_sender_id"))
68 }
69}
70
71impl<T: RequestMessage> TypedEnvelope<T> {
72 pub fn receipt(&self) -> Receipt<T> {
73 Receipt {
74 sender_id: self.sender_id,
75 message_id: self.message_id,
76 payload_type: PhantomData,
77 }
78 }
79}
80
81pub struct Peer {
82 connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
83 next_connection_id: AtomicU32,
84}
85
86#[derive(Clone)]
87struct ConnectionState {
88 outgoing_tx: mpsc::Sender<proto::Envelope>,
89 next_message_id: Arc<AtomicU32>,
90 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
91}
92
93impl Peer {
94 pub fn new() -> Arc<Self> {
95 Arc::new(Self {
96 connections: Default::default(),
97 next_connection_id: Default::default(),
98 })
99 }
100
101 pub async fn add_connection(
102 self: &Arc<Self>,
103 connection: Connection,
104 ) -> (
105 ConnectionId,
106 impl Future<Output = anyhow::Result<()>> + Send,
107 mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
108 ) {
109 let connection_id = ConnectionId(
110 self.next_connection_id
111 .fetch_add(1, atomic::Ordering::SeqCst),
112 );
113 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
114 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
115 let connection_state = ConnectionState {
116 outgoing_tx,
117 next_message_id: Default::default(),
118 response_channels: Default::default(),
119 };
120 let mut writer = MessageStream::new(connection.tx);
121 let mut reader = MessageStream::new(connection.rx);
122
123 let this = self.clone();
124 let response_channels = connection_state.response_channels.clone();
125 let handle_io = async move {
126 loop {
127 let read_message = reader.read_message().fuse();
128 futures::pin_mut!(read_message);
129 loop {
130 futures::select_biased! {
131 incoming = read_message => match incoming {
132 Ok(incoming) => {
133 if let Some(responding_to) = incoming.responding_to {
134 let channel = response_channels.lock().await.remove(&responding_to);
135 if let Some(mut tx) = channel {
136 tx.send(incoming).await.ok();
137 } else {
138 log::warn!("received RPC response to unknown request {}", responding_to);
139 }
140 } else {
141 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
142 if incoming_tx.send(envelope).await.is_err() {
143 response_channels.lock().await.clear();
144 this.connections.write().await.remove(&connection_id);
145 return Ok(())
146 }
147 } else {
148 log::error!("unable to construct a typed envelope");
149 }
150 }
151
152 break;
153 }
154 Err(error) => {
155 response_channels.lock().await.clear();
156 this.connections.write().await.remove(&connection_id);
157 Err(error).context("received invalid RPC message")?;
158 }
159 },
160 outgoing = outgoing_rx.recv().fuse() => match outgoing {
161 Some(outgoing) => {
162 if let Err(result) = writer.write_message(&outgoing).await {
163 response_channels.lock().await.clear();
164 this.connections.write().await.remove(&connection_id);
165 Err(result).context("failed to write RPC message")?;
166 }
167 }
168 None => {
169 response_channels.lock().await.clear();
170 this.connections.write().await.remove(&connection_id);
171 return Ok(())
172 }
173 }
174 }
175 }
176 }
177 };
178
179 self.connections
180 .write()
181 .await
182 .insert(connection_id, connection_state);
183
184 (connection_id, handle_io, incoming_rx)
185 }
186
187 pub async fn disconnect(&self, connection_id: ConnectionId) {
188 self.connections.write().await.remove(&connection_id);
189 }
190
191 pub async fn reset(&self) {
192 self.connections.write().await.clear();
193 }
194
195 pub fn request<T: RequestMessage>(
196 self: &Arc<Self>,
197 receiver_id: ConnectionId,
198 request: T,
199 ) -> impl Future<Output = Result<T::Response>> {
200 self.request_internal(None, receiver_id, request)
201 }
202
203 pub fn forward_request<T: RequestMessage>(
204 self: &Arc<Self>,
205 sender_id: ConnectionId,
206 receiver_id: ConnectionId,
207 request: T,
208 ) -> impl Future<Output = Result<T::Response>> {
209 self.request_internal(Some(sender_id), receiver_id, request)
210 }
211
212 pub fn request_internal<T: RequestMessage>(
213 self: &Arc<Self>,
214 original_sender_id: Option<ConnectionId>,
215 receiver_id: ConnectionId,
216 request: T,
217 ) -> impl Future<Output = Result<T::Response>> {
218 let this = self.clone();
219 let (tx, mut rx) = mpsc::channel(1);
220 async move {
221 let mut connection = this.connection_state(receiver_id).await?;
222 let message_id = connection
223 .next_message_id
224 .fetch_add(1, atomic::Ordering::SeqCst);
225 connection
226 .response_channels
227 .lock()
228 .await
229 .insert(message_id, tx);
230 connection
231 .outgoing_tx
232 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
233 .await
234 .map_err(|_| anyhow!("connection was closed"))?;
235 let response = rx
236 .recv()
237 .await
238 .ok_or_else(|| anyhow!("connection was closed"))?;
239 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
240 Err(anyhow!("request failed").context(error.message.clone()))
241 } else {
242 T::Response::from_envelope(response)
243 .ok_or_else(|| anyhow!("received response of the wrong type"))
244 }
245 }
246 }
247
248 pub fn send<T: EnvelopedMessage>(
249 self: &Arc<Self>,
250 receiver_id: ConnectionId,
251 message: T,
252 ) -> impl Future<Output = Result<()>> {
253 let this = self.clone();
254 async move {
255 let mut connection = this.connection_state(receiver_id).await?;
256 let message_id = connection
257 .next_message_id
258 .fetch_add(1, atomic::Ordering::SeqCst);
259 connection
260 .outgoing_tx
261 .send(message.into_envelope(message_id, None, None))
262 .await?;
263 Ok(())
264 }
265 }
266
267 pub fn forward_send<T: EnvelopedMessage>(
268 self: &Arc<Self>,
269 sender_id: ConnectionId,
270 receiver_id: ConnectionId,
271 message: T,
272 ) -> impl Future<Output = Result<()>> {
273 let this = self.clone();
274 async move {
275 let mut connection = this.connection_state(receiver_id).await?;
276 let message_id = connection
277 .next_message_id
278 .fetch_add(1, atomic::Ordering::SeqCst);
279 connection
280 .outgoing_tx
281 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
282 .await?;
283 Ok(())
284 }
285 }
286
287 pub fn respond<T: RequestMessage>(
288 self: &Arc<Self>,
289 receipt: Receipt<T>,
290 response: T::Response,
291 ) -> impl Future<Output = Result<()>> {
292 let this = self.clone();
293 async move {
294 let mut connection = this.connection_state(receipt.sender_id).await?;
295 let message_id = connection
296 .next_message_id
297 .fetch_add(1, atomic::Ordering::SeqCst);
298 connection
299 .outgoing_tx
300 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
301 .await?;
302 Ok(())
303 }
304 }
305
306 pub fn respond_with_error<T: RequestMessage>(
307 self: &Arc<Self>,
308 receipt: Receipt<T>,
309 response: proto::Error,
310 ) -> impl Future<Output = Result<()>> {
311 let this = self.clone();
312 async move {
313 let mut connection = this.connection_state(receipt.sender_id).await?;
314 let message_id = connection
315 .next_message_id
316 .fetch_add(1, atomic::Ordering::SeqCst);
317 connection
318 .outgoing_tx
319 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
320 .await?;
321 Ok(())
322 }
323 }
324
325 fn connection_state(
326 self: &Arc<Self>,
327 connection_id: ConnectionId,
328 ) -> impl Future<Output = Result<ConnectionState>> {
329 let this = self.clone();
330 async move {
331 let connections = this.connections.read().await;
332 let connection = connections
333 .get(&connection_id)
334 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
335 Ok(connection.clone())
336 }
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use futures::StreamExt as _;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping {},)
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping {},)
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client disconnects using the id its own peer assigned.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Serves `Ping` and `OpenBuffer` requests arriving on `messages`.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(receipt, proto::Ack {}).await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            // Disconnecting must end the I/O future and close the incoming stream.
            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(server_conn
                .send(WebSocketMessage::Binary(vec![]))
                .await
                .is_err());
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn, _) = Connection::in_memory();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(connection_id, proto::Ping {})
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}