1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 future::Future,
16 pin::Pin,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21};
22
23type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
24type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
25
/// Identifier of a connection registered with a `Peer` via `Peer::add_connection`.
///
/// Each peer hands out identifiers from its own counter starting at zero, so identifiers
/// issued by different peers can be equal and are only meaningful to the peer that issued them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u32);
28
/// Per-connection state owned by a `Peer`.
///
/// `reader` and `writer` wrap clones of the same connection object (see
/// `Peer::add_connection`), each behind its own lock, so holding one does not block the other.
struct Connection {
    /// Outgoing side: `send`, `request` and `respond` all write through this.
    writer: Mutex<MessageStream<BoxedWriter>>,
    /// Incoming side: locked by `handle_messages` and `receive` while reading.
    reader: Mutex<MessageStream<BoxedReader>>,
    /// Requests awaiting a response, keyed by the id of the outgoing request. An entry is
    /// removed when an envelope whose `responding_to` matches that id is read.
    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
    /// Counter from which ids of outgoing messages on this connection are allocated.
    next_message_id: AtomicU32,
}
35
36type MessageHandler = Box<
37 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
38>;
39
/// A decoded incoming message together with the connection it arrived on.
pub struct TypedEnvelope<T> {
    /// Id the sender assigned to the message. `Peer::respond` sends it back along with the
    /// response so the requester can match the two.
    pub id: u32,
    /// Connection the message was read from.
    pub connection_id: ConnectionId,
    /// The decoded message.
    pub payload: T,
}
45
/// An RPC endpoint that routes requests, responses and unsolicited messages over any number
/// of registered connections.
pub struct Peer {
    /// Registered connections, keyed by the id returned from `add_connection`.
    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
    /// One barrier per connection being served by `handle_messages`; dropping the sender
    /// (as `disconnect` does) ends that loop.
    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
    /// Handlers for messages that are not responses, tried in registration order.
    message_handlers: RwLock<Vec<MessageHandler>>,
    /// Message types that have had a handler registered, used to reject duplicates. Entries
    /// are never removed, even after the corresponding handler is dropped.
    handler_types: Mutex<HashSet<TypeId>>,
    /// Counter from which connection ids are allocated.
    next_connection_id: AtomicU32,
}
53
54impl Peer {
55 pub fn new() -> Arc<Self> {
56 Arc::new(Self {
57 connections: Default::default(),
58 connection_close_barriers: Default::default(),
59 message_handlers: Default::default(),
60 handler_types: Default::default(),
61 next_connection_id: Default::default(),
62 })
63 }
64
65 pub async fn add_message_handler<T: EnvelopedMessage>(
66 &self,
67 ) -> mpsc::Receiver<TypedEnvelope<T>> {
68 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
69 panic!("duplicate handler type");
70 }
71
72 let (tx, rx) = mpsc::channel(256);
73 self.message_handlers
74 .write()
75 .await
76 .push(Box::new(move |envelope, connection_id| {
77 if envelope.as_ref().map_or(false, T::matches_envelope) {
78 let envelope = Option::take(envelope).unwrap();
79 let mut tx = tx.clone();
80 Some(
81 async move {
82 tx.send(TypedEnvelope {
83 id: envelope.id,
84 connection_id,
85 payload: T::from_envelope(envelope).unwrap(),
86 })
87 .await
88 .is_err()
89 }
90 .boxed(),
91 )
92 } else {
93 None
94 }
95 }));
96 rx
97 }
98
99 pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
100 where
101 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
102 {
103 let connection_id = ConnectionId(
104 self.next_connection_id
105 .fetch_add(1, atomic::Ordering::SeqCst),
106 );
107 self.connections.write().await.insert(
108 connection_id,
109 Arc::new(Connection {
110 reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
111 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
112 response_channels: Default::default(),
113 next_message_id: Default::default(),
114 }),
115 );
116 connection_id
117 }
118
119 pub async fn disconnect(&self, connection_id: ConnectionId) {
120 self.connections.write().await.remove(&connection_id);
121 self.connection_close_barriers
122 .write()
123 .await
124 .remove(&connection_id);
125 }
126
127 pub fn handle_messages(
128 self: &Arc<Self>,
129 connection_id: ConnectionId,
130 ) -> impl Future<Output = Result<()>> + 'static {
131 let (close_tx, mut close_rx) = barrier::channel();
132 let this = self.clone();
133 async move {
134 this.connection_close_barriers
135 .write()
136 .await
137 .insert(connection_id, close_tx);
138 let connection = this.connection(connection_id).await?;
139 let closed = close_rx.recv();
140 futures::pin_mut!(closed);
141
142 loop {
143 let mut reader = connection.reader.lock().await;
144 let read_message = reader.read_message();
145 futures::pin_mut!(read_message);
146
147 match futures::future::select(read_message, &mut closed).await {
148 Either::Left((Ok(incoming), _)) => {
149 if let Some(responding_to) = incoming.responding_to {
150 let channel = connection
151 .response_channels
152 .lock()
153 .await
154 .remove(&responding_to);
155 if let Some(mut tx) = channel {
156 tx.send(incoming).await.ok();
157 } else {
158 log::warn!(
159 "received RPC response to unknown request {}",
160 responding_to
161 );
162 }
163 } else {
164 let mut envelope = Some(incoming);
165 let mut handler_index = None;
166 let mut handler_was_dropped = false;
167 for (i, handler) in
168 this.message_handlers.read().await.iter().enumerate()
169 {
170 if let Some(future) = handler(&mut envelope, connection_id) {
171 handler_was_dropped = future.await;
172 handler_index = Some(i);
173 break;
174 }
175 }
176
177 if let Some(handler_index) = handler_index {
178 if handler_was_dropped {
179 drop(this.message_handlers.write().await.remove(handler_index));
180 }
181 } else {
182 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
183 }
184 }
185 }
186 Either::Left((Err(error), _)) => {
187 log::warn!("received invalid RPC message: {}", error);
188 Err(error)?;
189 }
190 Either::Right(_) => return Ok(()),
191 }
192 }
193 }
194 }
195
196 pub async fn receive<M: EnvelopedMessage>(
197 self: &Arc<Self>,
198 connection_id: ConnectionId,
199 ) -> Result<TypedEnvelope<M>> {
200 let connection = self.connection(connection_id).await?;
201 let envelope = connection.reader.lock().await.read_message().await?;
202 let id = envelope.id;
203 let payload =
204 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
205 Ok(TypedEnvelope {
206 id,
207 connection_id,
208 payload,
209 })
210 }
211
212 pub fn request<T: RequestMessage>(
213 self: &Arc<Self>,
214 connection_id: ConnectionId,
215 req: T,
216 ) -> impl Future<Output = Result<T::Response>> {
217 let this = self.clone();
218 let (tx, mut rx) = oneshot::channel();
219 async move {
220 let connection = this.connection(connection_id).await?;
221 let message_id = connection
222 .next_message_id
223 .fetch_add(1, atomic::Ordering::SeqCst);
224 connection
225 .response_channels
226 .lock()
227 .await
228 .insert(message_id, tx);
229 connection
230 .writer
231 .lock()
232 .await
233 .write_message(&req.into_envelope(message_id, None))
234 .await?;
235 let response = rx
236 .recv()
237 .await
238 .expect("response channel was unexpectedly dropped");
239 T::Response::from_envelope(response)
240 .ok_or_else(|| anyhow!("received response of the wrong type"))
241 }
242 }
243
244 pub fn send<T: EnvelopedMessage>(
245 self: &Arc<Self>,
246 connection_id: ConnectionId,
247 message: T,
248 ) -> impl Future<Output = Result<()>> {
249 let this = self.clone();
250 async move {
251 let connection = this.connection(connection_id).await?;
252 let message_id = connection
253 .next_message_id
254 .fetch_add(1, atomic::Ordering::SeqCst);
255 connection
256 .writer
257 .lock()
258 .await
259 .write_message(&message.into_envelope(message_id, None))
260 .await?;
261 Ok(())
262 }
263 }
264
265 pub fn respond<T: RequestMessage>(
266 self: &Arc<Self>,
267 request: TypedEnvelope<T>,
268 response: T::Response,
269 ) -> impl Future<Output = Result<()>> {
270 let this = self.clone();
271 async move {
272 let connection = this.connection(request.connection_id).await?;
273 let message_id = connection
274 .next_message_id
275 .fetch_add(1, atomic::Ordering::SeqCst);
276 connection
277 .writer
278 .lock()
279 .await
280 .write_message(&response.into_envelope(message_id, Some(request.id)))
281 .await?;
282 Ok(())
283 }
284 }
285
286 async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
287 Ok(self
288 .connections
289 .read()
290 .await
291 .get(&id)
292 .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
293 .clone())
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::OpenWorktree {
                worktree_id: 101,
                access_token: "first-worktree-access-token".to_string(),
            };
            let response1 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/one".to_vec()],
                }),
            };
            let request2 = proto::OpenWorktree {
                worktree_id: 102,
                access_token: "second-worktree-access-token".to_string(),
            };
            let response2 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
                }),
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: b"path/two".to_vec(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: b"path/two".to_vec(),
                    content: b"path/two content".to_vec(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: b"path/one".to_vec(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: b"path/one".to_vec(),
                    content: b"path/one content".to_vec(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg, response1).await.unwrap();

                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg, response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg, response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg, response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client disconnects its own connection. Ids are allocated per peer, so
            // `client1_conn_id` merely happens to equal `client2_conn_id` here.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}