protocol.rs

 1use anyhow::Result;
 2use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 3use prost::Message as _;
 4use rpc::proto::Envelope;
 5
 6#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
 7pub struct MessageId(pub u32);
 8
 9pub type MessageLen = u32;
10pub const MESSAGE_LEN_SIZE: usize = size_of::<MessageLen>();
11
12pub fn message_len_from_buffer(buffer: &[u8]) -> MessageLen {
13    MessageLen::from_le_bytes(buffer.try_into().unwrap())
14}
15
16pub async fn read_message_with_len<S: AsyncRead + Unpin>(
17    stream: &mut S,
18    buffer: &mut Vec<u8>,
19    message_len: MessageLen,
20) -> Result<Envelope> {
21    buffer.resize(message_len as usize, 0);
22    stream.read_exact(buffer).await?;
23    Ok(Envelope::decode(buffer.as_slice())?)
24}
25
26pub async fn read_message<S: AsyncRead + Unpin>(
27    stream: &mut S,
28    buffer: &mut Vec<u8>,
29) -> Result<Envelope> {
30    buffer.resize(MESSAGE_LEN_SIZE, 0);
31    stream.read_exact(buffer).await?;
32
33    let len = message_len_from_buffer(buffer);
34
35    read_message_with_len(stream, buffer, len).await
36}
37
38pub async fn write_message<S: AsyncWrite + Unpin>(
39    stream: &mut S,
40    buffer: &mut Vec<u8>,
41    message: Envelope,
42) -> Result<()> {
43    let message_len = message.encoded_len() as u32;
44    stream
45        .write_all(message_len.to_le_bytes().as_slice())
46        .await?;
47    buffer.clear();
48    buffer.reserve(message_len as usize);
49    message.encode(buffer)?;
50    stream.write_all(buffer).await?;
51    Ok(())
52}
53
54pub async fn write_size_prefixed_buffer<S: AsyncWrite + Unpin>(
55    stream: &mut S,
56    buffer: &mut Vec<u8>,
57) -> Result<()> {
58    let len = buffer.len() as u32;
59    stream.write_all(len.to_le_bytes().as_slice()).await?;
60    stream.write_all(buffer).await?;
61    Ok(())
62}
63
64pub async fn read_message_raw<S: AsyncRead + Unpin>(
65    stream: &mut S,
66    buffer: &mut Vec<u8>,
67) -> Result<()> {
68    buffer.resize(MESSAGE_LEN_SIZE, 0);
69    stream.read_exact(buffer).await?;
70
71    let message_len = message_len_from_buffer(buffer);
72    buffer.resize(message_len as usize, 0);
73    stream.read_exact(buffer).await?;
74
75    Ok(())
76}