protocol.rs

 1use anyhow::Result;
 2use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 3use prost::Message as _;
 4use rpc::proto::Envelope;
 5
 6#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
 7pub struct MessageId(pub u32);
 8
 9pub type MessageLen = u32;
10pub const MESSAGE_LEN_SIZE: usize = size_of::<MessageLen>();
11
12pub fn message_len_from_buffer(buffer: &[u8]) -> MessageLen {
13    MessageLen::from_le_bytes(buffer.try_into().unwrap())
14}
15
16pub async fn read_message_with_len<S: AsyncRead + Unpin>(
17    stream: &mut S,
18    buffer: &mut Vec<u8>,
19    message_len: MessageLen,
20) -> Result<Envelope> {
21    buffer.resize(message_len as usize, 0);
22    stream.read_exact(buffer).await?;
23    Ok(Envelope::decode(buffer.as_slice())?)
24}
25
26pub async fn read_message<S: AsyncRead + Unpin>(
27    stream: &mut S,
28    buffer: &mut Vec<u8>,
29) -> Result<Envelope> {
30    buffer.resize(MESSAGE_LEN_SIZE, 0);
31    stream.read_exact(buffer).await?;
32
33    let len = message_len_from_buffer(buffer);
34
35    read_message_with_len(stream, buffer, len).await
36}
37
38pub async fn write_message<S: AsyncWrite + Unpin>(
39    stream: &mut S,
40    buffer: &mut Vec<u8>,
41    message: Envelope,
42) -> Result<()> {
43    let message_len = message.encoded_len() as u32;
44    stream
45        .write_all(message_len.to_le_bytes().as_slice())
46        .await?;
47    buffer.clear();
48    buffer.reserve(message_len as usize);
49    message.encode(buffer)?;
50    stream.write_all(buffer).await?;
51    Ok(())
52}
53
54pub async fn read_message_raw<S: AsyncRead + Unpin>(
55    stream: &mut S,
56    buffer: &mut Vec<u8>,
57) -> Result<()> {
58    buffer.resize(MESSAGE_LEN_SIZE, 0);
59    stream.read_exact(buffer).await?;
60
61    let message_len = message_len_from_buffer(buffer);
62    buffer.resize(message_len as usize, 0);
63    stream.read_exact(buffer).await?;
64
65    Ok(())
66}