1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use async_lock::{Mutex, RwLock};
5use futures::FutureExt as _;
6use postage::{
7 mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use std::{
11 collections::HashMap,
12 fmt,
13 future::Future,
14 marker::PhantomData,
15 sync::{
16 atomic::{self, AtomicU32},
17 Arc,
18 },
19};
20
/// Identifies one connection registered with a [`Peer`].
///
/// Ids are allocated by `Peer::add_connection` from a counter owned by each `Peer`, so ids
/// are only unique within the peer that issued them.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Prints the raw numeric id, forwarding any width/fill/alignment flags in `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
29
/// Identifies the peer on whose behalf a message was forwarded; see
/// `TypedEnvelope::original_sender_id`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Prints the raw numeric id, forwarding any width/fill/alignment flags in `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
38
39pub struct Receipt<T> {
40 pub sender_id: ConnectionId,
41 pub message_id: u32,
42 payload_type: PhantomData<T>,
43}
44
45impl<T> Clone for Receipt<T> {
46 fn clone(&self) -> Self {
47 Self {
48 sender_id: self.sender_id,
49 message_id: self.message_id,
50 payload_type: PhantomData,
51 }
52 }
53}
54
55impl<T> Copy for Receipt<T> {}
56
57pub struct TypedEnvelope<T> {
58 pub sender_id: ConnectionId,
59 pub original_sender_id: Option<PeerId>,
60 pub message_id: u32,
61 pub payload: T,
62}
63
64impl<T> TypedEnvelope<T> {
65 pub fn original_sender_id(&self) -> Result<PeerId> {
66 self.original_sender_id
67 .ok_or_else(|| anyhow!("missing original_sender_id"))
68 }
69}
70
71impl<T: RequestMessage> TypedEnvelope<T> {
72 pub fn receipt(&self) -> Receipt<T> {
73 Receipt {
74 sender_id: self.sender_id,
75 message_id: self.message_id,
76 payload_type: PhantomData,
77 }
78 }
79}
80
/// Sends and receives typed RPC messages over any number of [`Connection`]s.
pub struct Peer {
    /// State for every registered connection, keyed by id. Entries are removed by
    /// `disconnect`, `reset`, or when the connection's I/O future finishes.
    connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Counter from which `add_connection` allocates connection ids.
    next_connection_id: AtomicU32,
}

/// Per-connection handles shared between a connection's I/O future and the methods that
/// send on it. `Peer::connection_state` hands out clones of this.
#[derive(Clone)]
struct ConnectionState {
    /// Queue of envelopes that the connection's I/O future writes out.
    outgoing_tx: mpsc::Sender<proto::Envelope>,
    /// Counter used to assign ids to outgoing messages on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Requests awaiting a response, keyed by the request's message id. Set to `None` when
    /// the I/O future ends, which drops every waiting sender.
    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
}
92
93impl Peer {
94 pub fn new() -> Arc<Self> {
95 Arc::new(Self {
96 connections: Default::default(),
97 next_connection_id: Default::default(),
98 })
99 }
100
101 pub async fn add_connection(
102 self: &Arc<Self>,
103 connection: Connection,
104 ) -> (
105 ConnectionId,
106 impl Future<Output = anyhow::Result<()>> + Send,
107 mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
108 ) {
109 let connection_id = ConnectionId(
110 self.next_connection_id
111 .fetch_add(1, atomic::Ordering::SeqCst),
112 );
113 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
114 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
115 let connection_state = ConnectionState {
116 outgoing_tx,
117 next_message_id: Default::default(),
118 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
119 };
120 let mut writer = MessageStream::new(connection.tx);
121 let mut reader = MessageStream::new(connection.rx);
122
123 let this = self.clone();
124 let response_channels = connection_state.response_channels.clone();
125 let handle_io = async move {
126 let result = 'outer: loop {
127 let read_message = reader.read_message().fuse();
128 futures::pin_mut!(read_message);
129 loop {
130 futures::select_biased! {
131 incoming = read_message => match incoming {
132 Ok(incoming) => {
133 if let Some(responding_to) = incoming.responding_to {
134 let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
135 if let Some(mut tx) = channel {
136 tx.send(incoming).await.ok();
137 } else {
138 log::warn!("received RPC response to unknown request {}", responding_to);
139 }
140 } else {
141 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
142 if incoming_tx.send(envelope).await.is_err() {
143 break 'outer Ok(())
144 }
145 } else {
146 log::error!("unable to construct a typed envelope");
147 }
148 }
149
150 break;
151 }
152 Err(error) => {
153 break 'outer Err(error).context("received invalid RPC message")
154 }
155 },
156 outgoing = outgoing_rx.recv().fuse() => match outgoing {
157 Some(outgoing) => {
158 if let Err(result) = writer.write_message(&outgoing).await {
159 break 'outer Err(result).context("failed to write RPC message")
160 }
161 }
162 None => break 'outer Ok(()),
163 }
164 }
165 }
166 };
167
168 response_channels.lock().await.take();
169 this.connections.write().await.remove(&connection_id);
170 result
171 };
172
173 self.connections
174 .write()
175 .await
176 .insert(connection_id, connection_state);
177
178 (connection_id, handle_io, incoming_rx)
179 }
180
181 pub async fn disconnect(&self, connection_id: ConnectionId) {
182 self.connections.write().await.remove(&connection_id);
183 }
184
185 pub async fn reset(&self) {
186 self.connections.write().await.clear();
187 }
188
189 pub fn request<T: RequestMessage>(
190 self: &Arc<Self>,
191 receiver_id: ConnectionId,
192 request: T,
193 ) -> impl Future<Output = Result<T::Response>> {
194 self.request_internal(None, receiver_id, request)
195 }
196
197 pub fn forward_request<T: RequestMessage>(
198 self: &Arc<Self>,
199 sender_id: ConnectionId,
200 receiver_id: ConnectionId,
201 request: T,
202 ) -> impl Future<Output = Result<T::Response>> {
203 self.request_internal(Some(sender_id), receiver_id, request)
204 }
205
206 pub fn request_internal<T: RequestMessage>(
207 self: &Arc<Self>,
208 original_sender_id: Option<ConnectionId>,
209 receiver_id: ConnectionId,
210 request: T,
211 ) -> impl Future<Output = Result<T::Response>> {
212 let this = self.clone();
213 let (tx, mut rx) = mpsc::channel(1);
214 async move {
215 let mut connection = this.connection_state(receiver_id).await?;
216 let message_id = connection
217 .next_message_id
218 .fetch_add(1, atomic::Ordering::SeqCst);
219 connection
220 .response_channels
221 .lock()
222 .await
223 .as_mut()
224 .ok_or_else(|| anyhow!("connection was closed"))?
225 .insert(message_id, tx);
226 connection
227 .outgoing_tx
228 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
229 .await
230 .map_err(|_| anyhow!("connection was closed"))?;
231 let response = rx
232 .recv()
233 .await
234 .ok_or_else(|| anyhow!("connection was closed"))?;
235 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
236 Err(anyhow!("request failed").context(error.message.clone()))
237 } else {
238 T::Response::from_envelope(response)
239 .ok_or_else(|| anyhow!("received response of the wrong type"))
240 }
241 }
242 }
243
244 pub fn send<T: EnvelopedMessage>(
245 self: &Arc<Self>,
246 receiver_id: ConnectionId,
247 message: T,
248 ) -> impl Future<Output = Result<()>> {
249 let this = self.clone();
250 async move {
251 let mut connection = this.connection_state(receiver_id).await?;
252 let message_id = connection
253 .next_message_id
254 .fetch_add(1, atomic::Ordering::SeqCst);
255 connection
256 .outgoing_tx
257 .send(message.into_envelope(message_id, None, None))
258 .await?;
259 Ok(())
260 }
261 }
262
263 pub fn forward_send<T: EnvelopedMessage>(
264 self: &Arc<Self>,
265 sender_id: ConnectionId,
266 receiver_id: ConnectionId,
267 message: T,
268 ) -> impl Future<Output = Result<()>> {
269 let this = self.clone();
270 async move {
271 let mut connection = this.connection_state(receiver_id).await?;
272 let message_id = connection
273 .next_message_id
274 .fetch_add(1, atomic::Ordering::SeqCst);
275 connection
276 .outgoing_tx
277 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
278 .await?;
279 Ok(())
280 }
281 }
282
283 pub fn respond<T: RequestMessage>(
284 self: &Arc<Self>,
285 receipt: Receipt<T>,
286 response: T::Response,
287 ) -> impl Future<Output = Result<()>> {
288 let this = self.clone();
289 async move {
290 let mut connection = this.connection_state(receipt.sender_id).await?;
291 let message_id = connection
292 .next_message_id
293 .fetch_add(1, atomic::Ordering::SeqCst);
294 connection
295 .outgoing_tx
296 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
297 .await?;
298 Ok(())
299 }
300 }
301
302 pub fn respond_with_error<T: RequestMessage>(
303 self: &Arc<Self>,
304 receipt: Receipt<T>,
305 response: proto::Error,
306 ) -> impl Future<Output = Result<()>> {
307 let this = self.clone();
308 async move {
309 let mut connection = this.connection_state(receipt.sender_id).await?;
310 let message_id = connection
311 .next_message_id
312 .fetch_add(1, atomic::Ordering::SeqCst);
313 connection
314 .outgoing_tx
315 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
316 .await?;
317 Ok(())
318 }
319 }
320
321 fn connection_state(
322 self: &Arc<Self>,
323 connection_id: ConnectionId,
324 ) -> impl Future<Output = Result<ConnectionState>> {
325 let this = self.clone();
326 async move {
327 let connections = this.connections.read().await;
328 let connection = connections
329 .get(&connection_id)
330 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
331 Ok(connection.clone())
332 }
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use futures::StreamExt as _;

    /// Two clients issue requests to one server, and each receives the response to its own
    /// request.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping {})
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping {})
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client disconnects its own connection. Both ids happen to be 0, because
            // every `Peer` numbers its connections independently, so passing the wrong one
            // here would go unnoticed.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Answers `Ping` with `Ack` and `OpenBuffer` with a canned buffer for the two
            /// known paths; panics on anything else.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(receipt, proto::Ack {}).await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting locally ends the I/O future and the incoming stream, and closes the
    /// underlying connection so the remote side can no longer send.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(server_conn
                .send(WebSocketMessage::Binary(vec![]))
                .await
                .is_err());
        });
    }

    /// A request fails with "connection was closed" when the remote end drops the
    /// connection before responding.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let response = smol::spawn(client.request(connection_id, proto::Ping {}));
            let _request = server_conn.rx.next().await.unwrap().unwrap();

            drop(server_conn);
            assert_eq!(
                response.await.unwrap_err().to_string(),
                "connection was closed"
            );
        });
    }
}