1use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{FutureExt, StreamExt};
6use postage::{
7 mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use std::{
11 collections::HashMap,
12 fmt,
13 future::Future,
14 marker::PhantomData,
15 sync::{
16 atomic::{self, AtomicU32},
17 Arc,
18 },
19};
20
/// Identifies a single connection registered with a [`Peer`].
///
/// Ids are handed out by `Peer::add_connection` from a per-peer counter, so a
/// value is only meaningful relative to the peer that issued it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionId(
    /// Raw numeric id.
    pub u32,
);
23
24impl fmt::Display for ConnectionId {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 self.0.fmt(f)
27 }
28}
29
/// Numeric identifier of a peer.
///
/// In this file it is used as the type of `TypedEnvelope::original_sender_id`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PeerId(
    /// Raw numeric id.
    pub u32,
);
32
33impl fmt::Display for PeerId {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 self.0.fmt(f)
36 }
37}
38
39pub struct Receipt<T> {
40 pub sender_id: ConnectionId,
41 pub message_id: u32,
42 payload_type: PhantomData<T>,
43}
44
45impl<T> Clone for Receipt<T> {
46 fn clone(&self) -> Self {
47 Self {
48 sender_id: self.sender_id,
49 message_id: self.message_id,
50 payload_type: PhantomData,
51 }
52 }
53}
54
55impl<T> Copy for Receipt<T> {}
56
57pub struct TypedEnvelope<T> {
58 pub sender_id: ConnectionId,
59 pub original_sender_id: Option<PeerId>,
60 pub message_id: u32,
61 pub payload: T,
62}
63
64impl<T> TypedEnvelope<T> {
65 pub fn original_sender_id(&self) -> Result<PeerId> {
66 self.original_sender_id
67 .ok_or_else(|| anyhow!("missing original_sender_id"))
68 }
69}
70
71impl<T: RequestMessage> TypedEnvelope<T> {
72 pub fn receipt(&self) -> Receipt<T> {
73 Receipt {
74 sender_id: self.sender_id,
75 message_id: self.message_id,
76 payload_type: PhantomData,
77 }
78 }
79}
80
81pub struct Peer {
82 connections: RwLock<HashMap<ConnectionId, Connection>>,
83 next_connection_id: AtomicU32,
84}
85
86#[derive(Clone)]
87struct Connection {
88 outgoing_tx: mpsc::Sender<proto::Envelope>,
89 next_message_id: Arc<AtomicU32>,
90 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
91}
92
93impl Peer {
94 pub fn new() -> Arc<Self> {
95 Arc::new(Self {
96 connections: Default::default(),
97 next_connection_id: Default::default(),
98 })
99 }
100
101 pub async fn add_connection<Conn>(
102 self: &Arc<Self>,
103 conn: Conn,
104 ) -> (
105 ConnectionId,
106 impl Future<Output = anyhow::Result<()>> + Send,
107 mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
108 )
109 where
110 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
111 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
112 + Send
113 + Unpin,
114 {
115 let (tx, rx) = conn.split();
116 let connection_id = ConnectionId(
117 self.next_connection_id
118 .fetch_add(1, atomic::Ordering::SeqCst),
119 );
120 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
122 let connection = Connection {
123 outgoing_tx,
124 next_message_id: Default::default(),
125 response_channels: Default::default(),
126 };
127 let mut writer = MessageStream::new(tx);
128 let mut reader = MessageStream::new(rx);
129
130 let this = self.clone();
131 let response_channels = connection.response_channels.clone();
132 let handle_io = async move {
133 loop {
134 let read_message = reader.read_message().fuse();
135 futures::pin_mut!(read_message);
136 loop {
137 futures::select_biased! {
138 incoming = read_message => match incoming {
139 Ok(incoming) => {
140 if let Some(responding_to) = incoming.responding_to {
141 let channel = response_channels.lock().await.remove(&responding_to);
142 if let Some(mut tx) = channel {
143 tx.send(incoming).await.ok();
144 } else {
145 log::warn!("received RPC response to unknown request {}", responding_to);
146 }
147 } else {
148 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
149 if incoming_tx.send(envelope).await.is_err() {
150 response_channels.lock().await.clear();
151 this.connections.write().await.remove(&connection_id);
152 return Ok(())
153 }
154 } else {
155 log::error!("unable to construct a typed envelope");
156 }
157 }
158
159 break;
160 }
161 Err(error) => {
162 response_channels.lock().await.clear();
163 this.connections.write().await.remove(&connection_id);
164 Err(error).context("received invalid RPC message")?;
165 }
166 },
167 outgoing = outgoing_rx.recv().fuse() => match outgoing {
168 Some(outgoing) => {
169 if let Err(result) = writer.write_message(&outgoing).await {
170 response_channels.lock().await.clear();
171 this.connections.write().await.remove(&connection_id);
172 Err(result).context("failed to write RPC message")?;
173 }
174 }
175 None => {
176 response_channels.lock().await.clear();
177 this.connections.write().await.remove(&connection_id);
178 return Ok(())
179 }
180 }
181 }
182 }
183 }
184 };
185
186 self.connections
187 .write()
188 .await
189 .insert(connection_id, connection);
190
191 (connection_id, handle_io, incoming_rx)
192 }
193
194 pub async fn disconnect(&self, connection_id: ConnectionId) {
195 self.connections.write().await.remove(&connection_id);
196 }
197
198 pub async fn reset(&self) {
199 self.connections.write().await.clear();
200 }
201
202 pub fn request<T: RequestMessage>(
203 self: &Arc<Self>,
204 receiver_id: ConnectionId,
205 request: T,
206 ) -> impl Future<Output = Result<T::Response>> {
207 self.request_internal(None, receiver_id, request)
208 }
209
210 pub fn forward_request<T: RequestMessage>(
211 self: &Arc<Self>,
212 sender_id: ConnectionId,
213 receiver_id: ConnectionId,
214 request: T,
215 ) -> impl Future<Output = Result<T::Response>> {
216 self.request_internal(Some(sender_id), receiver_id, request)
217 }
218
219 pub fn request_internal<T: RequestMessage>(
220 self: &Arc<Self>,
221 original_sender_id: Option<ConnectionId>,
222 receiver_id: ConnectionId,
223 request: T,
224 ) -> impl Future<Output = Result<T::Response>> {
225 let this = self.clone();
226 let (tx, mut rx) = mpsc::channel(1);
227 async move {
228 let mut connection = this.connection(receiver_id).await?;
229 let message_id = connection
230 .next_message_id
231 .fetch_add(1, atomic::Ordering::SeqCst);
232 connection
233 .response_channels
234 .lock()
235 .await
236 .insert(message_id, tx);
237 connection
238 .outgoing_tx
239 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
240 .await
241 .map_err(|_| anyhow!("connection was closed"))?;
242 let response = rx
243 .recv()
244 .await
245 .ok_or_else(|| anyhow!("connection was closed"))?;
246 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
247 Err(anyhow!("request failed").context(error.message.clone()))
248 } else {
249 T::Response::from_envelope(response)
250 .ok_or_else(|| anyhow!("received response of the wrong type"))
251 }
252 }
253 }
254
255 pub fn send<T: EnvelopedMessage>(
256 self: &Arc<Self>,
257 receiver_id: ConnectionId,
258 message: T,
259 ) -> impl Future<Output = Result<()>> {
260 let this = self.clone();
261 async move {
262 let mut connection = this.connection(receiver_id).await?;
263 let message_id = connection
264 .next_message_id
265 .fetch_add(1, atomic::Ordering::SeqCst);
266 connection
267 .outgoing_tx
268 .send(message.into_envelope(message_id, None, None))
269 .await?;
270 Ok(())
271 }
272 }
273
274 pub fn forward_send<T: EnvelopedMessage>(
275 self: &Arc<Self>,
276 sender_id: ConnectionId,
277 receiver_id: ConnectionId,
278 message: T,
279 ) -> impl Future<Output = Result<()>> {
280 let this = self.clone();
281 async move {
282 let mut connection = this.connection(receiver_id).await?;
283 let message_id = connection
284 .next_message_id
285 .fetch_add(1, atomic::Ordering::SeqCst);
286 connection
287 .outgoing_tx
288 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
289 .await?;
290 Ok(())
291 }
292 }
293
294 pub fn respond<T: RequestMessage>(
295 self: &Arc<Self>,
296 receipt: Receipt<T>,
297 response: T::Response,
298 ) -> impl Future<Output = Result<()>> {
299 let this = self.clone();
300 async move {
301 let mut connection = this.connection(receipt.sender_id).await?;
302 let message_id = connection
303 .next_message_id
304 .fetch_add(1, atomic::Ordering::SeqCst);
305 connection
306 .outgoing_tx
307 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
308 .await?;
309 Ok(())
310 }
311 }
312
313 pub fn respond_with_error<T: RequestMessage>(
314 self: &Arc<Self>,
315 receipt: Receipt<T>,
316 response: proto::Error,
317 ) -> impl Future<Output = Result<()>> {
318 let this = self.clone();
319 async move {
320 let mut connection = this.connection(receipt.sender_id).await?;
321 let message_id = connection
322 .next_message_id
323 .fetch_add(1, atomic::Ordering::SeqCst);
324 connection
325 .outgoing_tx
326 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
327 .await?;
328 Ok(())
329 }
330 }
331
332 fn connection(
333 self: &Arc<Self>,
334 connection_id: ConnectionId,
335 ) -> impl Future<Output = Result<Connection>> {
336 let this = self.clone();
337 async move {
338 let connections = this.connections.read().await;
339 let connection = connections
340 .get(&connection_id)
341 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
342 Ok(connection.clone())
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};

    /// Two clients talk to one server over in-memory channels; each request
    /// must come back with the response the server-side handler produces.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            // The I/O futures do nothing until polled, so run them in the background.
            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping { id: 1 })
                    .await
                    .unwrap(),
                proto::Pong { id: 1 }
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping { id: 2 })
                    .await
                    .unwrap(),
                proto::Pong { id: 2 }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client closes its *own* connection. Connection ids are
            // allocated per peer, so client1's id must not be used on client2.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Server-side handler: answers pings with pongs and `OpenBuffer`
            /// requests with a canned buffer chosen by path.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(
                            receipt,
                            proto::Pong {
                                id: envelope.payload.id,
                            },
                        )
                        .await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting must end the I/O future, close the incoming message
    /// stream, and make the other end of the transport unwritable.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            // Signals when the I/O future has finished.
            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            // Signals when the incoming stream has yielded (here: ended).
            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request over a transport whose other end is already gone must fail
    /// with "connection was closed" rather than hang.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(connection_id, proto::Ping { id: 42 })
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}