1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use async_lock::{Mutex, RwLock};
5use futures::FutureExt as _;
6use postage::{
7 mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use smol_timeout::TimeoutExt as _;
11use std::sync::atomic::Ordering::SeqCst;
12use std::{
13 collections::HashMap,
14 fmt,
15 future::Future,
16 marker::PhantomData,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21 time::Duration,
22};
23
/// Identifies one connection registered with a [`Peer`].
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Formats as the bare numeric id, honoring any width/fill flags.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConnectionId(raw) = self;
        fmt::Display::fmt(raw, formatter)
    }
}
32
/// Identifies a remote peer as carried in an envelope's original-sender field.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Formats as the bare numeric id, honoring any width/fill flags.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PeerId(raw) = self;
        fmt::Display::fmt(raw, formatter)
    }
}
41
42pub struct Receipt<T> {
43 pub sender_id: ConnectionId,
44 pub message_id: u32,
45 payload_type: PhantomData<T>,
46}
47
48impl<T> Clone for Receipt<T> {
49 fn clone(&self) -> Self {
50 Self {
51 sender_id: self.sender_id,
52 message_id: self.message_id,
53 payload_type: PhantomData,
54 }
55 }
56}
57
58impl<T> Copy for Receipt<T> {}
59
60pub struct TypedEnvelope<T> {
61 pub sender_id: ConnectionId,
62 pub original_sender_id: Option<PeerId>,
63 pub message_id: u32,
64 pub payload: T,
65}
66
67impl<T> TypedEnvelope<T> {
68 pub fn original_sender_id(&self) -> Result<PeerId> {
69 self.original_sender_id
70 .ok_or_else(|| anyhow!("missing original_sender_id"))
71 }
72}
73
74impl<T: RequestMessage> TypedEnvelope<T> {
75 pub fn receipt(&self) -> Receipt<T> {
76 Receipt {
77 sender_id: self.sender_id,
78 message_id: self.message_id,
79 payload_type: PhantomData,
80 }
81 }
82}
83
/// An RPC endpoint that can hold several connections at once and exchange
/// messages and requests over each of them.
pub struct Peer {
    /// Registered connections, keyed by id. Entries are inserted by
    /// `add_connection` and removed by `disconnect`, `reset`, or when the
    /// connection's I/O future finishes.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Counter used to allocate ids for new connections (incremented once per
    /// `add_connection`).
    next_connection_id: AtomicU32,
}
88
/// Per-connection state shared between the connection's I/O future and the
/// `Peer` methods that send on it. The `Arc` fields are shared by all clones.
#[derive(Clone)]
pub struct ConnectionState {
    /// Queue of envelopes that the connection's I/O future writes out.
    outgoing_tx: mpsc::Sender<proto::Envelope>,
    /// Counter for the message ids of outgoing envelopes on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Requests awaiting a response, keyed by the request's message id.
    /// Becomes `None` once the I/O future exits, so that new requests fail
    /// immediately and pending ones see their channel dropped.
    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
}
95
/// Maximum time the I/O loop waits for one outgoing message to be written
/// before failing the connection (10 seconds).
const WRITE_TIMEOUT: Duration = Duration::from_millis(10_000);
97
98impl Peer {
99 pub fn new() -> Arc<Self> {
100 Arc::new(Self {
101 connections: Default::default(),
102 next_connection_id: Default::default(),
103 })
104 }
105
106 pub async fn add_connection(
107 self: &Arc<Self>,
108 connection: Connection,
109 ) -> (
110 ConnectionId,
111 impl Future<Output = anyhow::Result<()>> + Send,
112 mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
113 ) {
114 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
115 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
116 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
117 let connection_state = ConnectionState {
118 outgoing_tx,
119 next_message_id: Default::default(),
120 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
121 };
122 let mut writer = MessageStream::new(connection.tx);
123 let mut reader = MessageStream::new(connection.rx);
124
125 let this = self.clone();
126 let response_channels = connection_state.response_channels.clone();
127 let handle_io = async move {
128 let result = 'outer: loop {
129 let read_message = reader.read_message().fuse();
130 futures::pin_mut!(read_message);
131 loop {
132 futures::select_biased! {
133 incoming = read_message => match incoming {
134 Ok(incoming) => {
135 if let Some(responding_to) = incoming.responding_to {
136 let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
137 if let Some(mut tx) = channel {
138 tx.send(incoming).await.ok();
139 } else {
140 log::warn!("received RPC response to unknown request {}", responding_to);
141 }
142 } else {
143 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
144 if incoming_tx.send(envelope).await.is_err() {
145 break 'outer Ok(())
146 }
147 } else {
148 log::error!("unable to construct a typed envelope");
149 }
150 }
151
152 break;
153 }
154 Err(error) => {
155 break 'outer Err(error).context("received invalid RPC message")
156 }
157 },
158 outgoing = outgoing_rx.recv().fuse() => match outgoing {
159 Some(outgoing) => {
160 match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
161 None => break 'outer Err(anyhow!("timed out writing RPC message")),
162 Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
163 _ => {}
164 }
165 }
166 None => break 'outer Ok(()),
167 }
168 }
169 }
170 };
171
172 response_channels.lock().await.take();
173 this.connections.write().await.remove(&connection_id);
174 result
175 };
176
177 self.connections
178 .write()
179 .await
180 .insert(connection_id, connection_state);
181
182 (connection_id, handle_io, incoming_rx)
183 }
184
185 pub async fn disconnect(&self, connection_id: ConnectionId) {
186 self.connections.write().await.remove(&connection_id);
187 }
188
189 pub async fn reset(&self) {
190 self.connections.write().await.clear();
191 }
192
193 pub fn request<T: RequestMessage>(
194 self: &Arc<Self>,
195 receiver_id: ConnectionId,
196 request: T,
197 ) -> impl Future<Output = Result<T::Response>> {
198 self.request_internal(None, receiver_id, request)
199 }
200
201 pub fn forward_request<T: RequestMessage>(
202 self: &Arc<Self>,
203 sender_id: ConnectionId,
204 receiver_id: ConnectionId,
205 request: T,
206 ) -> impl Future<Output = Result<T::Response>> {
207 self.request_internal(Some(sender_id), receiver_id, request)
208 }
209
210 pub fn request_internal<T: RequestMessage>(
211 self: &Arc<Self>,
212 original_sender_id: Option<ConnectionId>,
213 receiver_id: ConnectionId,
214 request: T,
215 ) -> impl Future<Output = Result<T::Response>> {
216 let this = self.clone();
217 let (tx, mut rx) = mpsc::channel(1);
218 async move {
219 let mut connection = this.connection_state(receiver_id).await?;
220 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
221 connection
222 .response_channels
223 .lock()
224 .await
225 .as_mut()
226 .ok_or_else(|| anyhow!("connection was closed"))?
227 .insert(message_id, tx);
228 connection
229 .outgoing_tx
230 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
231 .await
232 .map_err(|_| anyhow!("connection was closed"))?;
233 let response = rx
234 .recv()
235 .await
236 .ok_or_else(|| anyhow!("connection was closed"))?;
237 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
238 Err(anyhow!("request failed").context(error.message.clone()))
239 } else {
240 T::Response::from_envelope(response)
241 .ok_or_else(|| anyhow!("received response of the wrong type"))
242 }
243 }
244 }
245
246 pub fn send<T: EnvelopedMessage>(
247 self: &Arc<Self>,
248 receiver_id: ConnectionId,
249 message: T,
250 ) -> impl Future<Output = Result<()>> {
251 let this = self.clone();
252 async move {
253 let mut connection = this.connection_state(receiver_id).await?;
254 let message_id = connection
255 .next_message_id
256 .fetch_add(1, atomic::Ordering::SeqCst);
257 connection
258 .outgoing_tx
259 .send(message.into_envelope(message_id, None, None))
260 .await?;
261 Ok(())
262 }
263 }
264
265 pub fn forward_send<T: EnvelopedMessage>(
266 self: &Arc<Self>,
267 sender_id: ConnectionId,
268 receiver_id: ConnectionId,
269 message: T,
270 ) -> impl Future<Output = Result<()>> {
271 let this = self.clone();
272 async move {
273 let mut connection = this.connection_state(receiver_id).await?;
274 let message_id = connection
275 .next_message_id
276 .fetch_add(1, atomic::Ordering::SeqCst);
277 connection
278 .outgoing_tx
279 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
280 .await?;
281 Ok(())
282 }
283 }
284
285 pub fn respond<T: RequestMessage>(
286 self: &Arc<Self>,
287 receipt: Receipt<T>,
288 response: T::Response,
289 ) -> impl Future<Output = Result<()>> {
290 let this = self.clone();
291 async move {
292 let mut connection = this.connection_state(receipt.sender_id).await?;
293 let message_id = connection
294 .next_message_id
295 .fetch_add(1, atomic::Ordering::SeqCst);
296 connection
297 .outgoing_tx
298 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
299 .await?;
300 Ok(())
301 }
302 }
303
304 pub fn respond_with_error<T: RequestMessage>(
305 self: &Arc<Self>,
306 receipt: Receipt<T>,
307 response: proto::Error,
308 ) -> impl Future<Output = Result<()>> {
309 let this = self.clone();
310 async move {
311 let mut connection = this.connection_state(receipt.sender_id).await?;
312 let message_id = connection
313 .next_message_id
314 .fetch_add(1, atomic::Ordering::SeqCst);
315 connection
316 .outgoing_tx
317 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
318 .await?;
319 Ok(())
320 }
321 }
322
323 fn connection_state(
324 self: &Arc<Self>,
325 connection_id: ConnectionId,
326 ) -> impl Future<Output = Result<ConnectionState>> {
327 let this = self.clone();
328 async move {
329 let connections = this.connections.read().await;
330 let connection = connections
331 .get(&connection_id)
332 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
333 Ok(connection.clone())
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use futures::StreamExt as _;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping {})
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping {})
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostics: None,
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostics: None,
                    }),
                }
            );

            // Each client disconnects its own connection (client2 previously
            // used client1's id by mistake, which left its connection open).
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(receipt, proto::Ack {}).await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostics: None,
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostics: None,
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(server_conn
                .send(WebSocketMessage::Binary(vec![]))
                .await
                .is_err());
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let response = smol::spawn(client.request(connection_id, proto::Ping {}));
            let _request = server_conn.rx.next().await.unwrap().unwrap();

            drop(server_conn);
            assert_eq!(
                response.await.unwrap_err().to_string(),
                "connection was closed"
            );
        });
    }
}