1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use async_lock::{Mutex, RwLock};
5use futures::FutureExt as _;
6use postage::{
7 mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use smol_timeout::TimeoutExt as _;
11use std::{
12 collections::HashMap,
13 fmt,
14 future::Future,
15 marker::PhantomData,
16 sync::{
17 atomic::{self, AtomicU32},
18 Arc,
19 },
20 time::Duration,
21};
22
/// Identifier a [`Peer`] assigns to one of its connections when it is added.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Writes the raw numeric id, forwarding the caller's formatter flags (width, fill, ...).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
31
/// Identifier of a peer; carried as the original sender of forwarded messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Writes the raw numeric id, forwarding the caller's formatter flags (width, fill, ...).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
40
41pub struct Receipt<T> {
42 pub sender_id: ConnectionId,
43 pub message_id: u32,
44 payload_type: PhantomData<T>,
45}
46
47impl<T> Clone for Receipt<T> {
48 fn clone(&self) -> Self {
49 Self {
50 sender_id: self.sender_id,
51 message_id: self.message_id,
52 payload_type: PhantomData,
53 }
54 }
55}
56
57impl<T> Copy for Receipt<T> {}
58
59pub struct TypedEnvelope<T> {
60 pub sender_id: ConnectionId,
61 pub original_sender_id: Option<PeerId>,
62 pub message_id: u32,
63 pub payload: T,
64}
65
66impl<T> TypedEnvelope<T> {
67 pub fn original_sender_id(&self) -> Result<PeerId> {
68 self.original_sender_id
69 .ok_or_else(|| anyhow!("missing original_sender_id"))
70 }
71}
72
73impl<T: RequestMessage> TypedEnvelope<T> {
74 pub fn receipt(&self) -> Receipt<T> {
75 Receipt {
76 sender_id: self.sender_id,
77 message_id: self.message_id,
78 payload_type: PhantomData,
79 }
80 }
81}
82
83pub struct Peer {
84 connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
85 next_connection_id: AtomicU32,
86}
87
88#[derive(Clone)]
89struct ConnectionState {
90 outgoing_tx: mpsc::Sender<proto::Envelope>,
91 next_message_id: Arc<AtomicU32>,
92 response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
93}
94
95const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
96
97impl Peer {
98 pub fn new() -> Arc<Self> {
99 Arc::new(Self {
100 connections: Default::default(),
101 next_connection_id: Default::default(),
102 })
103 }
104
105 pub async fn add_connection(
106 self: &Arc<Self>,
107 connection: Connection,
108 ) -> (
109 ConnectionId,
110 impl Future<Output = anyhow::Result<()>> + Send,
111 mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
112 ) {
113 let connection_id = ConnectionId(
114 self.next_connection_id
115 .fetch_add(1, atomic::Ordering::SeqCst),
116 );
117 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
118 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
119 let connection_state = ConnectionState {
120 outgoing_tx,
121 next_message_id: Default::default(),
122 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
123 };
124 let mut writer = MessageStream::new(connection.tx);
125 let mut reader = MessageStream::new(connection.rx);
126
127 let this = self.clone();
128 let response_channels = connection_state.response_channels.clone();
129 let handle_io = async move {
130 let result = 'outer: loop {
131 let read_message = reader.read_message().fuse();
132 futures::pin_mut!(read_message);
133 loop {
134 futures::select_biased! {
135 incoming = read_message => match incoming {
136 Ok(incoming) => {
137 if let Some(responding_to) = incoming.responding_to {
138 let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
139 if let Some(mut tx) = channel {
140 tx.send(incoming).await.ok();
141 } else {
142 log::warn!("received RPC response to unknown request {}", responding_to);
143 }
144 } else {
145 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
146 if incoming_tx.send(envelope).await.is_err() {
147 break 'outer Ok(())
148 }
149 } else {
150 log::error!("unable to construct a typed envelope");
151 }
152 }
153
154 break;
155 }
156 Err(error) => {
157 break 'outer Err(error).context("received invalid RPC message")
158 }
159 },
160 outgoing = outgoing_rx.recv().fuse() => match outgoing {
161 Some(outgoing) => {
162 match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
163 None => break 'outer Err(anyhow!("timed out writing RPC message")),
164 Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
165 _ => {}
166 }
167 }
168 None => break 'outer Ok(()),
169 }
170 }
171 }
172 };
173
174 response_channels.lock().await.take();
175 this.connections.write().await.remove(&connection_id);
176 result
177 };
178
179 self.connections
180 .write()
181 .await
182 .insert(connection_id, connection_state);
183
184 (connection_id, handle_io, incoming_rx)
185 }
186
187 pub async fn disconnect(&self, connection_id: ConnectionId) {
188 self.connections.write().await.remove(&connection_id);
189 }
190
191 pub async fn reset(&self) {
192 self.connections.write().await.clear();
193 }
194
195 pub fn request<T: RequestMessage>(
196 self: &Arc<Self>,
197 receiver_id: ConnectionId,
198 request: T,
199 ) -> impl Future<Output = Result<T::Response>> {
200 self.request_internal(None, receiver_id, request)
201 }
202
203 pub fn forward_request<T: RequestMessage>(
204 self: &Arc<Self>,
205 sender_id: ConnectionId,
206 receiver_id: ConnectionId,
207 request: T,
208 ) -> impl Future<Output = Result<T::Response>> {
209 self.request_internal(Some(sender_id), receiver_id, request)
210 }
211
212 pub fn request_internal<T: RequestMessage>(
213 self: &Arc<Self>,
214 original_sender_id: Option<ConnectionId>,
215 receiver_id: ConnectionId,
216 request: T,
217 ) -> impl Future<Output = Result<T::Response>> {
218 let this = self.clone();
219 let (tx, mut rx) = mpsc::channel(1);
220 async move {
221 let mut connection = this.connection_state(receiver_id).await?;
222 let message_id = connection
223 .next_message_id
224 .fetch_add(1, atomic::Ordering::SeqCst);
225 connection
226 .response_channels
227 .lock()
228 .await
229 .as_mut()
230 .ok_or_else(|| anyhow!("connection was closed"))?
231 .insert(message_id, tx);
232 connection
233 .outgoing_tx
234 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
235 .await
236 .map_err(|_| anyhow!("connection was closed"))?;
237 let response = rx
238 .recv()
239 .await
240 .ok_or_else(|| anyhow!("connection was closed"))?;
241 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
242 Err(anyhow!("request failed").context(error.message.clone()))
243 } else {
244 T::Response::from_envelope(response)
245 .ok_or_else(|| anyhow!("received response of the wrong type"))
246 }
247 }
248 }
249
250 pub fn send<T: EnvelopedMessage>(
251 self: &Arc<Self>,
252 receiver_id: ConnectionId,
253 message: T,
254 ) -> impl Future<Output = Result<()>> {
255 let this = self.clone();
256 async move {
257 let mut connection = this.connection_state(receiver_id).await?;
258 let message_id = connection
259 .next_message_id
260 .fetch_add(1, atomic::Ordering::SeqCst);
261 connection
262 .outgoing_tx
263 .send(message.into_envelope(message_id, None, None))
264 .await?;
265 Ok(())
266 }
267 }
268
269 pub fn forward_send<T: EnvelopedMessage>(
270 self: &Arc<Self>,
271 sender_id: ConnectionId,
272 receiver_id: ConnectionId,
273 message: T,
274 ) -> impl Future<Output = Result<()>> {
275 let this = self.clone();
276 async move {
277 let mut connection = this.connection_state(receiver_id).await?;
278 let message_id = connection
279 .next_message_id
280 .fetch_add(1, atomic::Ordering::SeqCst);
281 connection
282 .outgoing_tx
283 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
284 .await?;
285 Ok(())
286 }
287 }
288
289 pub fn respond<T: RequestMessage>(
290 self: &Arc<Self>,
291 receipt: Receipt<T>,
292 response: T::Response,
293 ) -> impl Future<Output = Result<()>> {
294 let this = self.clone();
295 async move {
296 let mut connection = this.connection_state(receipt.sender_id).await?;
297 let message_id = connection
298 .next_message_id
299 .fetch_add(1, atomic::Ordering::SeqCst);
300 connection
301 .outgoing_tx
302 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
303 .await?;
304 Ok(())
305 }
306 }
307
308 pub fn respond_with_error<T: RequestMessage>(
309 self: &Arc<Self>,
310 receipt: Receipt<T>,
311 response: proto::Error,
312 ) -> impl Future<Output = Result<()>> {
313 let this = self.clone();
314 async move {
315 let mut connection = this.connection_state(receipt.sender_id).await?;
316 let message_id = connection
317 .next_message_id
318 .fetch_add(1, atomic::Ordering::SeqCst);
319 connection
320 .outgoing_tx
321 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
322 .await?;
323 Ok(())
324 }
325 }
326
327 fn connection_state(
328 self: &Arc<Self>,
329 connection_id: ConnectionId,
330 ) -> impl Future<Output = Result<ConnectionState>> {
331 let this = self.clone();
332 async move {
333 let connections = this.connections.read().await;
334 let connection = connections
335 .get(&connection_id)
336 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
337 Ok(connection.clone())
338 }
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use futures::StreamExt as _;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping {})
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping {})
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostics: None,
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostics: None,
                    }),
                }
            );

            // Connection ids are only meaningful to the peer that issued them, so each
            // client must disconnect using its own id.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Serves `Ping` and `OpenBuffer` requests arriving on `messages`, panicking on
            /// anything unexpected.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(receipt, proto::Ack {}).await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostics: None,
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostics: None,
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(server_conn
                .send(WebSocketMessage::Binary(vec![]))
                .await
                .is_err());
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let response = smol::spawn(client.request(connection_id, proto::Ping {}));
            let _request = server_conn.rx.next().await.unwrap().unwrap();

            drop(server_conn);
            assert_eq!(
                response.await.unwrap_err().to_string(),
                "connection was closed"
            );
        });
    }
}