1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 future::Future,
16 pin::Pin,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21};
22
23type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
24type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
25
/// Opaque handle for a connection registered with a [`Peer`].
///
/// Values are allocated by [`Peer::add_connection`] from a counter owned by that peer,
/// so ids are only meaningful relative to the peer that produced them.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct ConnectionId(u32);
28
29struct Connection {
30 writer: Mutex<MessageStream<BoxedWriter>>,
31 reader: Mutex<MessageStream<BoxedReader>>,
32 response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
33 next_message_id: AtomicU32,
34}
35
36type MessageHandler = Box<
37 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
38>;
39
40pub struct TypedEnvelope<T> {
41 id: u32,
42 connection_id: ConnectionId,
43 payload: T,
44}
45
46impl<T> TypedEnvelope<T> {
47 pub fn connection_id(&self) -> ConnectionId {
48 self.connection_id
49 }
50
51 pub fn payload(&self) -> &T {
52 &self.payload
53 }
54}
55
56pub struct Peer {
57 connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
58 connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
59 message_handlers: RwLock<Vec<MessageHandler>>,
60 handler_types: Mutex<HashSet<TypeId>>,
61 next_connection_id: AtomicU32,
62}
63
64impl Peer {
65 pub fn new() -> Arc<Self> {
66 Arc::new(Self {
67 connections: Default::default(),
68 connection_close_barriers: Default::default(),
69 message_handlers: Default::default(),
70 handler_types: Default::default(),
71 next_connection_id: Default::default(),
72 })
73 }
74
75 pub async fn add_message_handler<T: EnvelopedMessage>(
76 &self,
77 ) -> mpsc::Receiver<TypedEnvelope<T>> {
78 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
79 panic!("duplicate handler type");
80 }
81
82 let (tx, rx) = mpsc::channel(256);
83 self.message_handlers
84 .write()
85 .await
86 .push(Box::new(move |envelope, connection_id| {
87 if envelope.as_ref().map_or(false, T::matches_envelope) {
88 let envelope = Option::take(envelope).unwrap();
89 let mut tx = tx.clone();
90 Some(
91 async move {
92 tx.send(TypedEnvelope {
93 id: envelope.id,
94 connection_id,
95 payload: T::from_envelope(envelope).unwrap(),
96 })
97 .await
98 .is_err()
99 }
100 .boxed(),
101 )
102 } else {
103 None
104 }
105 }));
106 rx
107 }
108
109 pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
110 where
111 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
112 {
113 let connection_id = ConnectionId(
114 self.next_connection_id
115 .fetch_add(1, atomic::Ordering::SeqCst),
116 );
117 self.connections.write().await.insert(
118 connection_id,
119 Arc::new(Connection {
120 reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
121 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
122 response_channels: Default::default(),
123 next_message_id: Default::default(),
124 }),
125 );
126 connection_id
127 }
128
129 pub async fn disconnect(&self, connection_id: ConnectionId) {
130 self.connections.write().await.remove(&connection_id);
131 self.connection_close_barriers
132 .write()
133 .await
134 .remove(&connection_id);
135 }
136
137 pub fn handle_messages(
138 self: &Arc<Self>,
139 connection_id: ConnectionId,
140 ) -> impl Future<Output = Result<()>> + 'static {
141 let (close_tx, mut close_rx) = barrier::channel();
142 let this = self.clone();
143 async move {
144 this.connection_close_barriers
145 .write()
146 .await
147 .insert(connection_id, close_tx);
148 let connection = this.connection(connection_id).await?;
149 let closed = close_rx.recv();
150 futures::pin_mut!(closed);
151
152 loop {
153 let mut reader = connection.reader.lock().await;
154 let read_message = reader.read_message();
155 futures::pin_mut!(read_message);
156
157 match futures::future::select(read_message, &mut closed).await {
158 Either::Left((Ok(incoming), _)) => {
159 if let Some(responding_to) = incoming.responding_to {
160 let channel = connection
161 .response_channels
162 .lock()
163 .await
164 .remove(&responding_to);
165 if let Some(mut tx) = channel {
166 tx.send(incoming).await.ok();
167 } else {
168 log::warn!(
169 "received RPC response to unknown request {}",
170 responding_to
171 );
172 }
173 } else {
174 let mut envelope = Some(incoming);
175 let mut handler_index = None;
176 let mut handler_was_dropped = false;
177 for (i, handler) in
178 this.message_handlers.read().await.iter().enumerate()
179 {
180 if let Some(future) = handler(&mut envelope, connection_id) {
181 handler_was_dropped = future.await;
182 handler_index = Some(i);
183 break;
184 }
185 }
186
187 if let Some(handler_index) = handler_index {
188 if handler_was_dropped {
189 drop(this.message_handlers.write().await.remove(handler_index));
190 }
191 } else {
192 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
193 }
194 }
195 }
196 Either::Left((Err(error), _)) => {
197 log::warn!("received invalid RPC message: {}", error);
198 Err(error)?;
199 }
200 Either::Right(_) => return Ok(()),
201 }
202 }
203 }
204 }
205
206 pub async fn receive<M: EnvelopedMessage>(
207 self: &Arc<Self>,
208 connection_id: ConnectionId,
209 ) -> Result<TypedEnvelope<M>> {
210 let connection = self.connection(connection_id).await?;
211 let envelope = connection.reader.lock().await.read_message().await?;
212 let id = envelope.id;
213 let payload =
214 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
215 Ok(TypedEnvelope {
216 id,
217 connection_id,
218 payload,
219 })
220 }
221
222 pub fn request<T: RequestMessage>(
223 self: &Arc<Self>,
224 connection_id: ConnectionId,
225 req: T,
226 ) -> impl Future<Output = Result<T::Response>> {
227 let this = self.clone();
228 let (tx, mut rx) = oneshot::channel();
229 async move {
230 let connection = this.connection(connection_id).await?;
231 let message_id = connection
232 .next_message_id
233 .fetch_add(1, atomic::Ordering::SeqCst);
234 connection
235 .response_channels
236 .lock()
237 .await
238 .insert(message_id, tx);
239 connection
240 .writer
241 .lock()
242 .await
243 .write_message(&req.into_envelope(message_id, None))
244 .await?;
245 let response = rx
246 .recv()
247 .await
248 .expect("response channel was unexpectedly dropped");
249 T::Response::from_envelope(response)
250 .ok_or_else(|| anyhow!("received response of the wrong type"))
251 }
252 }
253
254 pub fn send<T: EnvelopedMessage>(
255 self: &Arc<Self>,
256 connection_id: ConnectionId,
257 message: T,
258 ) -> impl Future<Output = Result<()>> {
259 let this = self.clone();
260 async move {
261 let connection = this.connection(connection_id).await?;
262 let message_id = connection
263 .next_message_id
264 .fetch_add(1, atomic::Ordering::SeqCst);
265 connection
266 .writer
267 .lock()
268 .await
269 .write_message(&message.into_envelope(message_id, None))
270 .await?;
271 Ok(())
272 }
273 }
274
275 pub fn respond<T: RequestMessage>(
276 self: &Arc<Self>,
277 request: TypedEnvelope<T>,
278 response: T::Response,
279 ) -> impl Future<Output = Result<()>> {
280 let this = self.clone();
281 async move {
282 let connection = this.connection(request.connection_id).await?;
283 let message_id = connection
284 .next_message_id
285 .fetch_add(1, atomic::Ordering::SeqCst);
286 connection
287 .writer
288 .lock()
289 .await
290 .write_message(&response.into_envelope(message_id, Some(request.id)))
291 .await?;
292 Ok(())
293 }
294 }
295
296 async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
297 Ok(self
298 .connections
299 .read()
300 .await
301 .get(&id)
302 .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
303 .clone())
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::OpenWorktree {
                worktree_id: 101,
                access_token: "first-worktree-access-token".to_string(),
            };
            let response1 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/one".to_vec()],
                }),
            };
            let request2 = proto::OpenWorktree {
                worktree_id: 102,
                access_token: "second-worktree-access-token".to_string(),
            };
            let response2 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
                }),
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: b"path/two".to_vec(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: b"path/two".to_vec(),
                    content: b"path/two content".to_vec(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: b"path/one".to_vec(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: b"path/one".to_vec(),
                    content: b"path/one content".to_vec(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The task gets its own copies; the originals are used by the clients below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg, response1).await.unwrap();

                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg, response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg, response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg, response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each peer numbers its connections from zero, so the two client ids happen to
            // be equal; still disconnect each client with its own id.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}