@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use gpui::executor::Background;
use parking_lot::Mutex;
use postage::{
- oneshot,
+ mpsc, oneshot,
prelude::{Sink, Stream},
};
use smol::{
@@ -11,11 +11,13 @@ use smol::{
prelude::{AsyncRead, AsyncWrite},
};
use std::{collections::HashMap, sync::Arc};
-use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage};
+use zed_rpc::proto::{
+ self, MessageStream, RequestMessage, SendMessage, ServerMessage, SubscribeMessage,
+};
pub struct RpcClient<Conn> {
stream: MessageStream<WriteHalf<Conn>>,
- response_channels: Arc<Mutex<HashMap<i32, oneshot::Sender<proto::from_server::Variant>>>>,
+ response_channels: Arc<Mutex<HashMap<i32, (mpsc::Sender<proto::from_server::Variant>, bool)>>>,
next_message_id: i32,
_drop_tx: oneshot::Sender<()>,
}
@@ -59,9 +61,15 @@ where
Message::Message(message) => {
if let Some(variant) = message.variant {
if let Some(request_id) = message.request_id {
- let tx = response_channels.lock().remove(&request_id);
- if let Some(mut tx) = tx {
- tx.send(variant).await?;
+ let channel = response_channels.lock().remove(&request_id);
+ if let Some((mut tx, oneshot)) = channel {
+ if tx.send(variant).await.is_ok() {
+ if !oneshot {
+ response_channels
+ .lock()
+ .insert(request_id, (tx, false));
+ }
+ }
} else {
log::warn!(
"received RPC response to unknown request id {}",
@@ -85,10 +93,8 @@ where
pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
let message_id = self.next_message_id;
self.next_message_id += 1;
-
- let (tx, mut rx) = oneshot::channel();
- self.response_channels.lock().insert(message_id, tx);
-
+ let (tx, mut rx) = mpsc::channel(1);
+ self.response_channels.lock().insert(message_id, (tx, true));
self.stream
.write_message(&proto::FromClient {
id: message_id,
@@ -114,6 +120,28 @@ where
.await?;
Ok(())
}
+
+ pub async fn subscribe<T: SubscribeMessage>(
+ &mut self,
+ subscription: T,
+ ) -> Result<impl Stream<Item = Result<T::Event>>> {
+ let message_id = self.next_message_id;
+ self.next_message_id += 1;
+ let (tx, rx) = mpsc::channel(256);
+ self.response_channels
+ .lock()
+ .insert(message_id, (tx, false));
+ self.stream
+ .write_message(&proto::FromClient {
+ id: message_id,
+ variant: Some(subscription.to_variant()),
+ })
+ .await?;
+
+ Ok(rx.map(|event| {
+ T::Event::from_variant(event).ok_or_else(|| anyhow!("invalid event {:?}"))
+ }))
+ }
}
#[cfg(test)]