1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 future::Future,
16 pin::Pin,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21};
22
23type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
24
/// Opaque identifier of one connection registered with a `Peer`.
///
/// Values are handed out by `Peer::add_connection` from a per-peer counter, so
/// an id is only meaningful to the peer that produced it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u32);
27
28struct Connection {
29 writer: Mutex<MessageStream<BoxedWriter>>,
30 response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
31 next_message_id: AtomicU32,
32}
33
34type MessageHandler = Box<
35 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
36>;
37
38pub struct TypedEnvelope<T> {
39 id: u32,
40 connection_id: ConnectionId,
41 payload: T,
42}
43
44impl<T> TypedEnvelope<T> {
45 pub fn connection_id(&self) -> ConnectionId {
46 self.connection_id
47 }
48
49 pub fn payload(&self) -> &T {
50 &self.payload
51 }
52}
53
54pub struct Peer {
55 connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
56 message_handlers: RwLock<Vec<MessageHandler>>,
57 handler_types: Mutex<HashSet<TypeId>>,
58 next_connection_id: AtomicU32,
59}
60
61impl Peer {
62 pub fn new() -> Arc<Self> {
63 Arc::new(Self {
64 connections: Default::default(),
65 message_handlers: Default::default(),
66 handler_types: Default::default(),
67 next_connection_id: Default::default(),
68 })
69 }
70
71 pub async fn add_message_handler<T: EnvelopedMessage>(
72 &self,
73 ) -> mpsc::Receiver<TypedEnvelope<T>> {
74 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
75 panic!("duplicate handler type");
76 }
77
78 let (tx, rx) = mpsc::channel(256);
79 self.message_handlers
80 .write()
81 .await
82 .push(Box::new(move |envelope, connection_id| {
83 if envelope.as_ref().map_or(false, T::matches_envelope) {
84 let envelope = Option::take(envelope).unwrap();
85 let mut tx = tx.clone();
86 Some(
87 async move {
88 tx.send(TypedEnvelope {
89 id: envelope.id,
90 connection_id,
91 payload: T::from_envelope(envelope).unwrap(),
92 })
93 .await
94 .is_err()
95 }
96 .boxed(),
97 )
98 } else {
99 None
100 }
101 }));
102 rx
103 }
104
105 pub async fn add_connection<Conn>(
106 self: &Arc<Self>,
107 conn: Conn,
108 ) -> (ConnectionId, impl Future<Output = Result<()>>)
109 where
110 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
111 {
112 let connection_id = ConnectionId(
113 self.next_connection_id
114 .fetch_add(1, atomic::Ordering::SeqCst),
115 );
116 let (close_tx, mut close_rx) = barrier::channel();
117 let connection = Arc::new(Connection {
118 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
119 response_channels: Default::default(),
120 next_message_id: Default::default(),
121 });
122
123 self.connections
124 .write()
125 .await
126 .insert(connection_id, (connection.clone(), close_tx));
127
128 let this = self.clone();
129 let handler_future = async move {
130 let closed = close_rx.recv();
131 futures::pin_mut!(closed);
132
133 let mut stream = MessageStream::new(conn);
134 loop {
135 let read_message = stream.read_message();
136 futures::pin_mut!(read_message);
137
138 match futures::future::select(read_message, &mut closed).await {
139 Either::Left((Ok(incoming), _)) => {
140 if let Some(responding_to) = incoming.responding_to {
141 let channel = connection
142 .response_channels
143 .lock()
144 .await
145 .remove(&responding_to);
146 if let Some(mut tx) = channel {
147 tx.send(incoming).await.ok();
148 } else {
149 log::warn!(
150 "received RPC response to unknown request {}",
151 responding_to
152 );
153 }
154 } else {
155 let mut envelope = Some(incoming);
156 let mut handler_index = None;
157 let mut handler_was_dropped = false;
158 for (i, handler) in
159 this.message_handlers.read().await.iter().enumerate()
160 {
161 if let Some(future) = handler(&mut envelope, connection_id) {
162 handler_was_dropped = future.await;
163 handler_index = Some(i);
164 break;
165 }
166 }
167
168 if let Some(handler_index) = handler_index {
169 if handler_was_dropped {
170 drop(this.message_handlers.write().await.remove(handler_index));
171 }
172 } else {
173 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
174 }
175 }
176 }
177 Either::Left((Err(error), _)) => {
178 log::warn!("received invalid RPC message: {}", error);
179 Err(error)?;
180 }
181 Either::Right(_) => return Ok(()),
182 }
183 }
184 };
185
186 (connection_id, handler_future)
187 }
188
189 pub async fn disconnect(&self, connection_id: ConnectionId) {
190 self.connections.write().await.remove(&connection_id);
191 }
192
193 pub fn request<T: RequestMessage>(
194 self: &Arc<Self>,
195 connection_id: ConnectionId,
196 req: T,
197 ) -> impl Future<Output = Result<T::Response>> {
198 let this = self.clone();
199 let (tx, mut rx) = oneshot::channel();
200 async move {
201 let connection = this.connection(connection_id).await?;
202 let message_id = connection
203 .next_message_id
204 .fetch_add(1, atomic::Ordering::SeqCst);
205 connection
206 .response_channels
207 .lock()
208 .await
209 .insert(message_id, tx);
210 connection
211 .writer
212 .lock()
213 .await
214 .write_message(&req.into_envelope(message_id, None))
215 .await?;
216 let response = rx
217 .recv()
218 .await
219 .expect("response channel was unexpectedly dropped");
220 T::Response::from_envelope(response)
221 .ok_or_else(|| anyhow!("received response of the wrong type"))
222 }
223 }
224
225 pub fn send<T: EnvelopedMessage>(
226 self: &Arc<Self>,
227 connection_id: ConnectionId,
228 message: T,
229 ) -> impl Future<Output = Result<()>> {
230 let this = self.clone();
231 async move {
232 let connection = this.connection(connection_id).await?;
233 let message_id = connection
234 .next_message_id
235 .fetch_add(1, atomic::Ordering::SeqCst);
236 connection
237 .writer
238 .lock()
239 .await
240 .write_message(&message.into_envelope(message_id, None))
241 .await?;
242 Ok(())
243 }
244 }
245
246 pub fn respond<T: RequestMessage>(
247 self: &Arc<Self>,
248 request: TypedEnvelope<T>,
249 response: T::Response,
250 ) -> impl Future<Output = Result<()>> {
251 let this = self.clone();
252 async move {
253 let connection = this.connection(request.connection_id).await?;
254 let message_id = connection
255 .next_message_id
256 .fetch_add(1, atomic::Ordering::SeqCst);
257 connection
258 .writer
259 .lock()
260 .await
261 .write_message(&response.into_envelope(message_id, Some(request.id)))
262 .await?;
263 Ok(())
264 }
265 }
266
267 async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
268 Ok(self
269 .connections
270 .read()
271 .await
272 .get(&id)
273 .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
274 .0
275 .clone())
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    /// Two clients each issue two requests to one server and must receive the
    /// responses the server task sends for them.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let (client1_conn_id, f1) = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let (client2_conn_id, f2) = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let (_, f3) = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let (_, f4) = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(f1).detach();
            smol::spawn(f2).detach();
            smol::spawn(f3).detach();
            smol::spawn(f4).detach();

            // define the expected requests and responses
            let request1 = proto::OpenWorktree {
                worktree_id: 101,
                access_token: "first-worktree-access-token".to_string(),
            };
            let response1 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/one".to_vec()],
                }),
            };
            let request2 = proto::OpenWorktree {
                worktree_id: 102,
                access_token: "second-worktree-access-token".to_string(),
            };
            let response2 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
                }),
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: b"path/two".to_vec(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: b"path/two".to_vec(),
                    content: b"path/two content".to_vec(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: b"path/one".to_vec(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: b"path/one".to_vec(),
                    content: b"path/one content".to_vec(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The server task needs its own copies; the originals are
                // still used by the client assertions below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg, response1).await.unwrap();

                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg, response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg, response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg, response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client closes its own connection. The ids are per-peer, so
            // client2 must be given the id that client2 handed out.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    /// Disconnecting must release the client's side of the socket, which the
    /// server observes as a broken pipe on write.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            smol::spawn(handler).detach();
            client.disconnect(connection_id).await;

            // Try sending an empty payload over and over, until the client is dropped and hangs up.
            loop {
                match server_conn.write(&[]).await {
                    Ok(_) => {}
                    Err(err) => {
                        if err.kind() == io::ErrorKind::BrokenPipe {
                            break;
                        }
                    }
                }
            }
        });
    }

    /// A request over an already closed socket must surface the underlying
    /// I/O error to the caller.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            smol::spawn(handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}