1use anyhow::Result;
2use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3use prost::Message as _;
4use rpc::proto::Envelope;
5use std::mem::size_of;
6
/// Newtype wrapper around a `u32` message identifier.
///
/// Being `Copy + Eq + Hash`, it can be passed by value and used as a map key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MessageId(pub u32);
9
/// Integer type of the little-endian length prefix written before each message.
pub type MessageLen = u32;

/// Number of bytes the length prefix occupies on the wire.
pub const MESSAGE_LEN_SIZE: usize = size_of::<MessageLen>();
12
13pub fn message_len_from_buffer(buffer: &[u8]) -> MessageLen {
14 MessageLen::from_le_bytes(buffer.try_into().unwrap())
15}
16
17pub async fn read_message_with_len<S: AsyncRead + Unpin>(
18 stream: &mut S,
19 buffer: &mut Vec<u8>,
20 message_len: MessageLen,
21) -> Result<Envelope> {
22 buffer.resize(message_len as usize, 0);
23 stream.read_exact(buffer).await?;
24 Ok(Envelope::decode(buffer.as_slice())?)
25}
26
27pub async fn read_message<S: AsyncRead + Unpin>(
28 stream: &mut S,
29 buffer: &mut Vec<u8>,
30) -> Result<Envelope> {
31 buffer.resize(MESSAGE_LEN_SIZE, 0);
32 stream.read_exact(buffer).await?;
33 let len = message_len_from_buffer(buffer);
34 read_message_with_len(stream, buffer, len).await
35}
36
37pub async fn write_message<S: AsyncWrite + Unpin>(
38 stream: &mut S,
39 buffer: &mut Vec<u8>,
40 message: Envelope,
41) -> Result<()> {
42 let message_len = message.encoded_len() as u32;
43 stream
44 .write_all(message_len.to_le_bytes().as_slice())
45 .await?;
46 buffer.clear();
47 buffer.reserve(message_len as usize);
48 message.encode(buffer)?;
49 stream.write_all(buffer).await?;
50 Ok(())
51}