@@ -5,42 +5,28 @@ use std::{convert::TryInto, io};
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
-/// A message that the client can send to the server.
-pub trait ClientMessage: Sized {
- fn to_variant(self) -> from_client::Variant;
- fn from_variant(variant: from_client::Variant) -> Option<Self>;
+pub trait EnvelopedMessage: Sized {
+ fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
+ fn from_envelope(envelope: Envelope) -> Option<Self>;
}
-/// A message that the server can send to the client.
-pub trait ServerMessage: Sized {
- fn to_variant(self) -> from_server::Variant;
- fn from_variant(variant: from_server::Variant) -> Option<Self>;
+pub trait RequestMessage: EnvelopedMessage {
+ type Response: EnvelopedMessage;
}
-/// A message that the client can send to the server, where the server must respond with a single
-/// message of a certain type.
-pub trait RequestMessage: ClientMessage {
- type Response: ServerMessage;
-}
-
-/// A message that the client can send to the server, where the server must respond with a series of
-/// messages of a certain type.
-pub trait SubscribeMessage: ClientMessage {
- type Event: ServerMessage;
-}
-
-/// A message that the client can send to the server, where the server will not respond.
-pub trait SendMessage: ClientMessage {}
-
-macro_rules! directed_message {
- ($name:ident, $direction_trait:ident, $direction_module:ident) => {
- impl $direction_trait for $direction_module::$name {
- fn to_variant(self) -> $direction_module::Variant {
- $direction_module::Variant::$name(self)
+macro_rules! message {
+ ($name:ident) => {
+ impl EnvelopedMessage for $name {
+ fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
+ Envelope {
+ id,
+ responding_to,
+ payload: Some(envelope::Payload::$name(self)),
+ }
}
- fn from_variant(variant: $direction_module::Variant) -> Option<Self> {
- if let $direction_module::Variant::$name(msg) = variant {
+ fn from_envelope(envelope: Envelope) -> Option<Self> {
+ if let Some(envelope::Payload::$name(msg)) = envelope.payload {
Some(msg)
} else {
None
@@ -52,36 +38,18 @@ macro_rules! directed_message {
macro_rules! request_message {
($req:ident, $resp:ident) => {
- directed_message!($req, ClientMessage, from_client);
- directed_message!($resp, ServerMessage, from_server);
- impl RequestMessage for from_client::$req {
- type Response = from_server::$resp;
- }
- };
-}
-
-macro_rules! send_message {
- ($msg:ident) => {
- directed_message!($msg, ClientMessage, from_client);
- impl SendMessage for from_client::$msg {}
- };
-}
-
-macro_rules! subscribe_message {
- ($subscription:ident, $event:ident) => {
- directed_message!($subscription, ClientMessage, from_client);
- directed_message!($event, ServerMessage, from_server);
- impl SubscribeMessage for from_client::$subscription {
- type Event = from_server::$event;
+ message!($req);
+ message!($resp);
+ impl RequestMessage for $req {
+ type Response = $resp;
}
};
}
request_message!(Auth, AuthResponse);
-request_message!(NewWorktree, NewWorktreeResponse);
request_message!(ShareWorktree, ShareWorktreeResponse);
-send_message!(UploadFile);
-subscribe_message!(SubscribeToPathRequests, PathRequest);
+request_message!(OpenWorktree, OpenWorktreeResponse);
+request_message!(OpenBuffer, OpenBufferResponse);
/// A stream of protobuf messages.
pub struct MessageStream<T> {
@@ -107,7 +75,7 @@ where
T: AsyncWrite + Unpin,
{
/// Write a given protobuf message to the stream.
- pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
+ pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
let message_len: u32 = message
.encoded_len()
.try_into()
@@ -124,13 +92,13 @@ where
T: AsyncRead + Unpin,
{
/// Read a protobuf message of the given type from the stream.
- pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
+ pub async fn read_message(&mut self) -> futures_io::Result<Envelope> {
let mut delimiter_buf = [0; 4];
self.byte_stream.read_exact(&mut delimiter_buf).await?;
let message_len = u32::from_be_bytes(delimiter_buf) as usize;
self.buffer.resize(message_len, 0);
self.byte_stream.read_exact(&mut self.buffer).await?;
- Ok(M::decode(self.buffer.as_slice())?)
+ Ok(Envelope::decode(self.buffer.as_slice())?)
}
}
@@ -151,30 +119,24 @@ mod tests {
chunk_size: 3,
};
- let message1 = FromClient {
- id: 3,
- variant: Some(from_client::Variant::Auth(from_client::Auth {
- user_id: 5,
- access_token: "the-access-token".into(),
- })),
- };
- let message2 = FromClient {
- id: 4,
- variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
- path: Vec::new(),
- content: format!(
- "a {}long error message that requires a two-byte length delimiter",
- "very ".repeat(60)
- )
- .into(),
- })),
- };
+ let message1 = Auth {
+ user_id: 5,
+ access_token: "the-access-token".into(),
+ }
+ .into_envelope(3, None);
+
+ let message2 = ShareWorktree {
+ worktree: Some(Worktree {
+ paths: vec![b"ok".to_vec()],
+ }),
+ }
+ .into_envelope(5, None);
let mut message_stream = MessageStream::new(byte_stream);
message_stream.write_message(&message1).await.unwrap();
message_stream.write_message(&message2).await.unwrap();
- let decoded_message1 = message_stream.read_message::<FromClient>().await.unwrap();
- let decoded_message2 = message_stream.read_message::<FromClient>().await.unwrap();
+ let decoded_message1 = message_stream.read_message().await.unwrap();
+ let decoded_message2 = message_stream.read_message().await.unwrap();
assert_eq!(decoded_message1, message1);
assert_eq!(decoded_message2, message2);
});
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use futures::future::Either;
use gpui::executor::Background;
use postage::{
- barrier, mpsc,
+ barrier, oneshot,
prelude::{Sink, Stream},
};
use smol::{
@@ -14,18 +14,16 @@ use std::{
collections::HashMap,
future::Future,
sync::{
- atomic::{self, AtomicI32},
+ atomic::{self, AtomicU32},
Arc,
},
};
-use zed_rpc::proto::{
- self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage,
-};
+use zed_rpc::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
pub struct RpcClient {
- response_channels: Arc<Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>>,
+ response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
outgoing: Mutex<MessageStream<BoxedWriter>>,
- next_message_id: AtomicI32,
+ next_message_id: AtomicU32,
_drop_tx: barrier::Sender,
}
@@ -50,16 +48,14 @@ impl RpcClient {
response_channels,
outgoing: Mutex::new(MessageStream::new(Box::pin(conn_tx))),
_drop_tx,
- next_message_id: AtomicI32::new(0),
+ next_message_id: AtomicU32::new(0),
})
}
async fn handle_incoming<Conn>(
conn: ReadHalf<Conn>,
mut drop_rx: barrier::Receiver,
- response_channels: Arc<
- Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>,
- >,
+ response_channels: Arc<Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>>,
) where
Conn: AsyncRead + Unpin,
{
@@ -68,36 +64,27 @@ impl RpcClient {
let mut stream = MessageStream::new(conn);
loop {
- let read_message = stream.read_message::<proto::FromServer>();
+ let read_message = stream.read_message();
smol::pin!(read_message);
match futures::future::select(read_message, &mut dropped).await {
Either::Left((Ok(incoming), _)) => {
- if let Some(variant) = incoming.variant {
- if let Some(request_id) = incoming.request_id {
- let channel = response_channels.lock().await.remove(&request_id);
- if let Some((mut tx, oneshot)) = channel {
- if tx.send(variant).await.is_ok() {
- if !oneshot {
- response_channels
- .lock()
- .await
- .insert(request_id, (tx, false));
- }
- }
- } else {
- log::warn!(
- "received RPC response to unknown request id {}",
- request_id
- );
- }
+ if let Some(responding_to) = incoming.responding_to {
+ let channel = response_channels.lock().await.remove(&responding_to);
+ if let Some(mut tx) = channel {
+ tx.send(incoming).await.ok();
+ } else {
+ log::warn!(
+ "received RPC response to unknown request {}",
+ responding_to
+ );
}
} else {
- log::warn!("received RPC message with no content");
+ // unprompted message from server
}
}
Either::Left((Err(error), _)) => {
- log::warn!("invalid incoming RPC message {:?}", error);
+ log::warn!("received invalid RPC message {:?}", error);
}
Either::Right(_) => break,
}
@@ -111,67 +98,35 @@ impl RpcClient {
let this = self.clone();
async move {
let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- let (tx, mut rx) = mpsc::channel(1);
- this.response_channels
- .lock()
- .await
- .insert(message_id, (tx, true));
+ let (tx, mut rx) = oneshot::channel();
+ this.response_channels.lock().await.insert(message_id, tx);
this.outgoing
.lock()
.await
- .write_message(&proto::FromClient {
- id: message_id,
- variant: Some(req.to_variant()),
- })
+ .write_message(&req.into_envelope(message_id, None))
.await?;
let response = rx
.recv()
.await
.expect("response channel was unexpectedly dropped");
- T::Response::from_variant(response)
+ T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong t"))
}
}
- pub fn send<T: SendMessage>(self: &Arc<Self>, message: T) -> impl Future<Output = Result<()>> {
- let this = self.clone();
- async move {
- let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- this.outgoing
- .lock()
- .await
- .write_message(&proto::FromClient {
- id: message_id,
- variant: Some(message.to_variant()),
- })
- .await?;
- Ok(())
- }
- }
-
- pub fn subscribe<T: SubscribeMessage>(
+ pub fn send<T: EnvelopedMessage>(
self: &Arc<Self>,
- subscription: T,
- ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Event>>>> {
+ message: T,
+ ) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
let message_id = this.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
- let (tx, rx) = mpsc::channel(256);
- this.response_channels
- .lock()
- .await
- .insert(message_id, (tx, false));
this.outgoing
.lock()
.await
- .write_message(&proto::FromClient {
- id: message_id,
- variant: Some(subscription.to_variant()),
- })
+ .write_message(&message.into_envelope(message_id, None))
.await?;
- Ok(rx.map(|event| {
- T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
- }))
+ Ok(())
}
}
}
@@ -199,133 +154,49 @@ mod tests {
let mut server_stream = MessageStream::new(server_conn);
let client = RpcClient::new(client_conn, executor.clone());
- let client_req = client.request(proto::from_client::Auth {
+ let client_req = client.request(proto::Auth {
user_id: 42,
access_token: "token".to_string(),
});
smol::pin!(client_req);
- let server_req = send_recv(
- &mut client_req,
- server_stream.read_message::<proto::FromClient>(),
- )
- .await
- .unwrap();
+ let server_req = send_recv(&mut client_req, server_stream.read_message())
+ .await
+ .unwrap();
assert_eq!(
- server_req.variant,
- Some(proto::from_client::Variant::Auth(
- proto::from_client::Auth {
- user_id: 42,
- access_token: "token".to_string()
- }
- ))
+ server_req.payload,
+ Some(proto::envelope::Payload::Auth(proto::Auth {
+ user_id: 42,
+ access_token: "token".to_string()
+ }))
);
// Respond to another request to ensure requests are properly matched up.
server_stream
- .write_message(&proto::FromServer {
- request_id: Some(999),
- variant: Some(proto::from_server::Variant::AuthResponse(
- proto::from_server::AuthResponse {
- credentials_valid: false,
- },
- )),
- })
+ .write_message(
+ &proto::AuthResponse {
+ credentials_valid: false,
+ }
+ .into_envelope(1000, Some(999)),
+ )
.await
.unwrap();
server_stream
- .write_message(&proto::FromServer {
- request_id: Some(server_req.id),
- variant: Some(proto::from_server::Variant::AuthResponse(
- proto::from_server::AuthResponse {
- credentials_valid: true,
- },
- )),
- })
+ .write_message(
+ &proto::AuthResponse {
+ credentials_valid: true,
+ }
+ .into_envelope(1001, Some(server_req.id)),
+ )
.await
.unwrap();
assert_eq!(
client_req.await.unwrap(),
- proto::from_server::AuthResponse {
+ proto::AuthResponse {
credentials_valid: true
}
);
}
- #[gpui::test]
- async fn test_subscribe(cx: gpui::TestAppContext) {
- let executor = cx.read(|app| app.background_executor().clone());
- let socket_dir_path = TempDir::new("subscribe").unwrap();
- let socket_path = socket_dir_path.path().join(".sock");
- let listener = UnixListener::bind(&socket_path).unwrap();
- let client_conn = UnixStream::connect(&socket_path).await.unwrap();
- let (server_conn, _) = listener.accept().await.unwrap();
-
- let mut server_stream = MessageStream::new(server_conn);
- let client = RpcClient::new(client_conn, executor.clone());
-
- let mut events = client
- .subscribe(proto::from_client::SubscribeToPathRequests {})
- .await
- .unwrap();
-
- let subscription = server_stream
- .read_message::<proto::FromClient>()
- .await
- .unwrap();
- assert_eq!(
- subscription.variant,
- Some(proto::from_client::Variant::SubscribeToPathRequests(
- proto::from_client::SubscribeToPathRequests {}
- ))
- );
- server_stream
- .write_message(&proto::FromServer {
- request_id: Some(subscription.id),
- variant: Some(proto::from_server::Variant::PathRequest(
- proto::from_server::PathRequest {
- path: b"path-1".to_vec(),
- },
- )),
- })
- .await
- .unwrap();
- server_stream
- .write_message(&proto::FromServer {
- request_id: Some(99999),
- variant: Some(proto::from_server::Variant::PathRequest(
- proto::from_server::PathRequest {
- path: b"path-2".to_vec(),
- },
- )),
- })
- .await
- .unwrap();
- server_stream
- .write_message(&proto::FromServer {
- request_id: Some(subscription.id),
- variant: Some(proto::from_server::Variant::PathRequest(
- proto::from_server::PathRequest {
- path: b"path-3".to_vec(),
- },
- )),
- })
- .await
- .unwrap();
-
- assert_eq!(
- events.recv().await.unwrap().unwrap(),
- proto::from_server::PathRequest {
- path: b"path-1".to_vec()
- }
- );
- assert_eq!(
- events.recv().await.unwrap().unwrap(),
- proto::from_server::PathRequest {
- path: b"path-3".to_vec()
- }
- );
- }
-
#[gpui::test]
async fn test_drop_client(cx: gpui::TestAppContext) {
let executor = cx.read(|app| app.background_executor().clone());
@@ -362,7 +233,7 @@ mod tests {
let client = RpcClient::new(client_conn, executor.clone());
let err = client
- .request(proto::from_client::Auth {
+ .request(proto::Auth {
user_id: 42,
access_token: "token".to_string(),
})