1use crate::{
2 proto::{self, EnvelopedMessage, MessageStream, RequestMessage},
3 ConnectionId, PeerId, Receipt,
4};
5use anyhow::{anyhow, Context, Result};
6use async_lock::{Mutex, RwLock};
7use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
8use futures::{FutureExt, StreamExt};
9use postage::{
10 mpsc,
11 prelude::{Sink as _, Stream as _},
12};
13use std::{
14 any::Any,
15 collections::HashMap,
16 future::Future,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21};
22
23pub struct Peer {
24 connections: RwLock<HashMap<ConnectionId, Connection>>,
25 next_connection_id: AtomicU32,
26}
27
28#[derive(Clone)]
29struct Connection {
30 outgoing_tx: mpsc::Sender<proto::Envelope>,
31 next_message_id: Arc<AtomicU32>,
32 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
33}
34
35impl Peer {
36 pub fn new() -> Arc<Self> {
37 Arc::new(Self {
38 connections: Default::default(),
39 next_connection_id: Default::default(),
40 })
41 }
42
43 pub async fn add_connection<Conn>(
44 self: &Arc<Self>,
45 conn: Conn,
46 ) -> (
47 ConnectionId,
48 impl Future<Output = anyhow::Result<()>> + Send,
49 mpsc::Receiver<Box<dyn Any + Sync + Send>>,
50 )
51 where
52 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
53 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
54 + Send
55 + Unpin,
56 {
57 let (tx, rx) = conn.split();
58 let connection_id = ConnectionId(
59 self.next_connection_id
60 .fetch_add(1, atomic::Ordering::SeqCst),
61 );
62 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
63 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
64 let connection = Connection {
65 outgoing_tx,
66 next_message_id: Default::default(),
67 response_channels: Default::default(),
68 };
69 let mut writer = MessageStream::new(tx);
70 let mut reader = MessageStream::new(rx);
71
72 let response_channels = connection.response_channels.clone();
73 let handle_io = async move {
74 loop {
75 let read_message = reader.read_message().fuse();
76 futures::pin_mut!(read_message);
77 loop {
78 futures::select_biased! {
79 incoming = read_message => match incoming {
80 Ok(incoming) => {
81 if let Some(responding_to) = incoming.responding_to {
82 let channel = response_channels.lock().await.remove(&responding_to);
83 if let Some(mut tx) = channel {
84 tx.send(incoming).await.ok();
85 } else {
86 log::warn!("received RPC response to unknown request {}", responding_to);
87 }
88 } else {
89 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
90 if incoming_tx.send(envelope).await.is_err() {
91 response_channels.lock().await.clear();
92 return Ok(())
93 }
94 } else {
95 log::error!("unable to construct a typed envelope");
96 }
97 }
98
99 break;
100 }
101 Err(error) => {
102 response_channels.lock().await.clear();
103 Err(error).context("received invalid RPC message")?;
104 }
105 },
106 outgoing = outgoing_rx.recv().fuse() => match outgoing {
107 Some(outgoing) => {
108 if let Err(result) = writer.write_message(&outgoing).await {
109 response_channels.lock().await.clear();
110 Err(result).context("failed to write RPC message")?;
111 }
112 }
113 None => {
114 response_channels.lock().await.clear();
115 return Ok(())
116 }
117 }
118 }
119 }
120 }
121 };
122
123 self.connections
124 .write()
125 .await
126 .insert(connection_id, connection);
127
128 (connection_id, handle_io, incoming_rx)
129 }
130
131 pub async fn disconnect(&self, connection_id: ConnectionId) {
132 self.connections.write().await.remove(&connection_id);
133 }
134
135 pub async fn reset(&self) {
136 self.connections.write().await.clear();
137 }
138
139 pub fn request<T: RequestMessage>(
140 self: &Arc<Self>,
141 receiver_id: ConnectionId,
142 request: T,
143 ) -> impl Future<Output = Result<T::Response>> {
144 self.request_internal(None, receiver_id, request)
145 }
146
147 pub fn forward_request<T: RequestMessage>(
148 self: &Arc<Self>,
149 sender_id: ConnectionId,
150 receiver_id: ConnectionId,
151 request: T,
152 ) -> impl Future<Output = Result<T::Response>> {
153 self.request_internal(Some(sender_id), receiver_id, request)
154 }
155
156 pub fn request_internal<T: RequestMessage>(
157 self: &Arc<Self>,
158 original_sender_id: Option<ConnectionId>,
159 receiver_id: ConnectionId,
160 request: T,
161 ) -> impl Future<Output = Result<T::Response>> {
162 let this = self.clone();
163 let (tx, mut rx) = mpsc::channel(1);
164 async move {
165 let mut connection = this.connection(receiver_id).await?;
166 let message_id = connection
167 .next_message_id
168 .fetch_add(1, atomic::Ordering::SeqCst);
169 connection
170 .response_channels
171 .lock()
172 .await
173 .insert(message_id, tx);
174 connection
175 .outgoing_tx
176 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
177 .await
178 .map_err(|_| anyhow!("connection was closed"))?;
179 let response = rx
180 .recv()
181 .await
182 .ok_or_else(|| anyhow!("connection was closed"))?;
183 T::Response::from_envelope(response)
184 .ok_or_else(|| anyhow!("received response of the wrong type"))
185 }
186 }
187
188 pub fn send<T: EnvelopedMessage>(
189 self: &Arc<Self>,
190 receiver_id: ConnectionId,
191 message: T,
192 ) -> impl Future<Output = Result<()>> {
193 let this = self.clone();
194 async move {
195 let mut connection = this.connection(receiver_id).await?;
196 let message_id = connection
197 .next_message_id
198 .fetch_add(1, atomic::Ordering::SeqCst);
199 connection
200 .outgoing_tx
201 .send(message.into_envelope(message_id, None, None))
202 .await?;
203 Ok(())
204 }
205 }
206
207 pub fn forward_send<T: EnvelopedMessage>(
208 self: &Arc<Self>,
209 sender_id: ConnectionId,
210 receiver_id: ConnectionId,
211 message: T,
212 ) -> impl Future<Output = Result<()>> {
213 let this = self.clone();
214 async move {
215 let mut connection = this.connection(receiver_id).await?;
216 let message_id = connection
217 .next_message_id
218 .fetch_add(1, atomic::Ordering::SeqCst);
219 connection
220 .outgoing_tx
221 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
222 .await?;
223 Ok(())
224 }
225 }
226
227 pub fn respond<T: RequestMessage>(
228 self: &Arc<Self>,
229 receipt: Receipt<T>,
230 response: T::Response,
231 ) -> impl Future<Output = Result<()>> {
232 let this = self.clone();
233 async move {
234 let mut connection = this.connection(receipt.sender_id).await?;
235 let message_id = connection
236 .next_message_id
237 .fetch_add(1, atomic::Ordering::SeqCst);
238 connection
239 .outgoing_tx
240 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
241 .await?;
242 Ok(())
243 }
244 }
245
246 fn connection(
247 self: &Arc<Self>,
248 connection_id: ConnectionId,
249 ) -> impl Future<Output = Result<Connection>> {
250 let this = self.clone();
251 async move {
252 let connections = this.connections.read().await;
253 let connection = connections
254 .get(&connection_id)
255 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
256 Ok(connection.clone())
257 }
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};

    /// Two clients connected to one server each send requests, and every response must come
    /// back to the client that asked.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping { id: 1 },)
                    .await
                    .unwrap(),
                proto::Pong { id: 1 }
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping { id: 2 },)
                    .await
                    .unwrap(),
                proto::Pong { id: 2 }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            client1.disconnect(client1_conn_id).await;
            // Each client must disconnect its *own* connection id.
            client2.disconnect(client2_conn_id).await;

            /// Server-side handler: answers `Ping` with `Pong` and `OpenBuffer` with a canned
            /// buffer, and panics on anything else.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(
                            receipt,
                            proto::Pong {
                                id: envelope.payload.id,
                            },
                        )
                        .await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting ends the IO future, closes the incoming-message stream, and drops the
    /// underlying socket so the other side can no longer send.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request on a connection whose remote end is already gone fails with
    /// "connection was closed" instead of hanging.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}