protocol.rs

 1use anyhow::Result;
 2use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 3use prost::Message as _;
 4use rpc::proto::Envelope;
 5use std::mem::size_of;
 6
 7#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
 8pub struct MessageId(pub u32);
 9
/// Integer type of the length prefix that precedes every encoded message.
pub type MessageLen = u32;

/// Number of bytes the [`MessageLen`] prefix occupies on the wire.
pub const MESSAGE_LEN_SIZE: usize = size_of::<MessageLen>();

/// Decodes a little-endian [`MessageLen`] length prefix from `buffer`.
///
/// # Panics
///
/// Panics if `buffer` is not exactly [`MESSAGE_LEN_SIZE`] bytes long. Callers
/// are expected to pass precisely the prefix bytes and nothing else.
pub fn message_len_from_buffer(buffer: &[u8]) -> MessageLen {
    MessageLen::from_le_bytes(
        buffer
            .try_into()
            .expect("length-prefix buffer must be exactly MESSAGE_LEN_SIZE bytes"),
    )
}
16
17pub async fn read_message_with_len<S: AsyncRead + Unpin>(
18    stream: &mut S,
19    buffer: &mut Vec<u8>,
20    message_len: MessageLen,
21) -> Result<Envelope> {
22    buffer.resize(message_len as usize, 0);
23    stream.read_exact(buffer).await?;
24    Ok(Envelope::decode(buffer.as_slice())?)
25}
26
27pub async fn read_message<S: AsyncRead + Unpin>(
28    stream: &mut S,
29    buffer: &mut Vec<u8>,
30) -> Result<Envelope> {
31    buffer.resize(MESSAGE_LEN_SIZE, 0);
32    stream.read_exact(buffer).await?;
33    let len = message_len_from_buffer(buffer);
34    read_message_with_len(stream, buffer, len).await
35}
36
37pub async fn write_message<S: AsyncWrite + Unpin>(
38    stream: &mut S,
39    buffer: &mut Vec<u8>,
40    message: Envelope,
41) -> Result<()> {
42    let message_len = message.encoded_len() as u32;
43    stream
44        .write_all(message_len.to_le_bytes().as_slice())
45        .await?;
46    buffer.clear();
47    buffer.reserve(message_len as usize);
48    message.encode(buffer)?;
49    stream.write_all(buffer).await?;
50    Ok(())
51}