Detailed changes
@@ -137,34 +137,6 @@ impl PeerExt for Peer {
}
}
-#[async_trait]
-impl PeerExt for zrpc::peer2::Peer {
- async fn sign_out(
- self: &Arc<Self>,
- connection_id: zrpc::ConnectionId,
- state: &AppState,
- ) -> tide::Result<()> {
- self.disconnect(connection_id).await;
- let worktree_ids = state.rpc.write().await.remove_connection(connection_id);
- for worktree_id in worktree_ids {
- let state = state.rpc.read().await;
- if let Some(worktree) = state.worktrees.get(&worktree_id) {
- rpc::broadcast(connection_id, worktree.connection_ids(), |conn_id| {
- self.send(
- conn_id,
- proto::RemovePeer {
- worktree_id,
- peer_id: connection_id.0,
- },
- )
- })
- .await?;
- }
- }
- Ok(())
- }
-}
-
pub fn build_client(client_id: &str, client_secret: &str) -> Client {
Client::new(
ClientId::new(client_id.to_string()),
@@ -30,13 +30,15 @@ use time::OffsetDateTime;
use zrpc::{
auth::random_token,
proto::{self, EnvelopedMessage},
- ConnectionId, Peer, Router, TypedEnvelope,
+ ConnectionId, Peer, TypedEnvelope,
};
type ReplicaId = u16;
type Handler = Box<
- dyn Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
+ dyn Send
+ + Sync
+ + Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
>;
#[derive(Default)]
@@ -48,7 +50,7 @@ struct ServerBuilder {
impl ServerBuilder {
pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
- F: 'static + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
+ F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
Fut: 'static + Send + Future<Output = ()>,
M: EnvelopedMessage,
{
@@ -73,23 +75,23 @@ impl ServerBuilder {
self
}
- pub fn build(self, rpc: Arc<zrpc::peer2::Peer>, state: Arc<AppState>) -> Arc<Server> {
+ pub fn build(self, rpc: &Arc<Peer>, state: &Arc<AppState>) -> Arc<Server> {
Arc::new(Server {
- rpc,
- state,
+ rpc: rpc.clone(),
+ state: state.clone(),
handlers: self.handlers,
})
}
}
-struct Server {
- rpc: Arc<zrpc::peer2::Peer>,
+pub struct Server {
+ rpc: Arc<Peer>,
state: Arc<AppState>,
handlers: Vec<Handler>,
}
impl Server {
- pub async fn add_connection<Conn>(
+ pub async fn handle_connection<Conn>(
self: &Arc<Self>,
connection: Conn,
addr: String,
@@ -332,99 +334,31 @@ impl State {
}
}
-trait MessageHandler<'a, M: proto::EnvelopedMessage> {
- type Output: 'a + Send + Future<Output = tide::Result<()>>;
-
- fn handle(
- &self,
- message: TypedEnvelope<M>,
- rpc: &'a Arc<Peer>,
- app_state: &'a Arc<AppState>,
- ) -> Self::Output;
-}
-
-impl<'a, M, F, Fut> MessageHandler<'a, M> for F
-where
- M: proto::EnvelopedMessage,
- F: Fn(TypedEnvelope<M>, &'a Arc<Peer>, &'a Arc<AppState>) -> Fut,
- Fut: 'a + Send + Future<Output = tide::Result<()>>,
-{
- type Output = Fut;
-
- fn handle(
- &self,
- message: TypedEnvelope<M>,
- rpc: &'a Arc<Peer>,
- app_state: &'a Arc<AppState>,
- ) -> Self::Output {
- (self)(message, rpc, app_state)
- }
-}
-
-fn on_message<M, H>(router: &mut Router, rpc: &Arc<Peer>, app_state: &Arc<AppState>, handler: H)
-where
- M: EnvelopedMessage,
- H: 'static + Clone + Send + Sync + for<'a> MessageHandler<'a, M>,
-{
- let rpc = rpc.clone();
- let handler = handler.clone();
- let app_state = app_state.clone();
- router.add_message_handler(move |message| {
- let rpc = rpc.clone();
- let handler = handler.clone();
- let app_state = app_state.clone();
- async move {
- let sender_id = message.sender_id;
- let message_id = message.message_id;
- let start_time = Instant::now();
- log::info!(
- "RPC message received. id: {}.{}, type:{}",
- sender_id,
- message_id,
- M::NAME
- );
- if let Err(err) = handler.handle(message, &rpc, &app_state).await {
- log::error!("error handling message: {:?}", err);
- } else {
- log::info!(
- "RPC message handled. id:{}.{}, duration:{:?}",
- sender_id,
- message_id,
- start_time.elapsed()
- );
- }
-
- Ok(())
- }
- });
-}
-
-pub fn add_rpc_routes(router: &mut Router, state: &Arc<AppState>, rpc: &Arc<Peer>) {
- on_message(router, rpc, state, share_worktree);
- on_message(router, rpc, state, join_worktree);
- on_message(router, rpc, state, update_worktree);
- on_message(router, rpc, state, close_worktree);
- on_message(router, rpc, state, open_buffer);
- on_message(router, rpc, state, close_buffer);
- on_message(router, rpc, state, update_buffer);
- on_message(router, rpc, state, buffer_saved);
- on_message(router, rpc, state, save_buffer);
- on_message(router, rpc, state, get_channels);
- on_message(router, rpc, state, get_users);
- on_message(router, rpc, state, join_channel);
- on_message(router, rpc, state, send_channel_message);
+pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> {
+ ServerBuilder::default()
+ // .on_message(share_worktree)
+ // .on_message(join_worktree)
+ // .on_message(update_worktree)
+ // .on_message(close_worktree)
+ // .on_message(open_buffer)
+ // .on_message(close_buffer)
+ // .on_message(update_buffer)
+ // .on_message(buffer_saved)
+ // .on_message(save_buffer)
+ // .on_message(get_channels)
+ // .on_message(get_users)
+ // .on_message(join_channel)
+ // .on_message(send_channel_message)
+ .build(rpc, state)
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
- let mut router = Router::new();
- add_rpc_routes(&mut router, app.state(), rpc);
- let router = Arc::new(router);
+ let server = build_server(app.state(), rpc);
let rpc = rpc.clone();
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
let user_id = request.ext::<UserId>().copied();
- let rpc = rpc.clone();
- let router = router.clone();
+ let server = server.clone();
async move {
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
@@ -451,12 +385,11 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let http_res: &mut tide::http::Response = response.as_mut();
let upgrade_receiver = http_res.recv_upgrade().await;
let addr = request.remote().unwrap_or("unknown").to_string();
- let state = request.state().clone();
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
- handle_connection(rpc, router, state, addr, stream, user_id).await;
+ server.handle_connection(stream, addr, user_id).await;
}
});
@@ -465,43 +398,6 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
});
}
-pub async fn handle_connection<Conn>(
- rpc: Arc<Peer>,
- router: Arc<Router>,
- state: Arc<AppState>,
- addr: String,
- stream: Conn,
- user_id: UserId,
-) where
- Conn: 'static
- + futures::Sink<WebSocketMessage, Error = WebSocketError>
- + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
- + Send
- + Unpin,
-{
- log::info!("accepted rpc connection: {:?}", addr);
- let (connection_id, handle_io, handle_messages) = rpc.add_connection(stream, router).await;
- state
- .rpc
- .write()
- .await
- .add_connection(connection_id, user_id);
-
- let handle_messages = async move {
- handle_messages.await;
- Ok(())
- };
-
- if let Err(e) = futures::try_join!(handle_messages, handle_io) {
- log::error!("error handling rpc connection {:?} - {:?}", addr, e);
- }
-
- log::info!("closing connection to {:?}", addr);
- if let Err(e) = rpc.sign_out(connection_id, &state).await {
- log::error!("error signing out connection {:?} - {:?}", addr, e);
- }
-}
-
async fn share_worktree(
mut request: TypedEnvelope<proto::ShareWorktree>,
rpc: &Arc<Peer>,
@@ -2,7 +2,7 @@ use crate::{
auth,
db::{self, UserId},
github,
- rpc::{self, add_rpc_routes},
+ rpc::{self, build_server},
AppState, Config,
};
use async_std::task;
@@ -24,7 +24,7 @@ use zed::{
test::Channel,
worktree::Worktree,
};
-use zrpc::{ForegroundRouter, Peer, Router};
+use zrpc::Peer;
#[gpui::test]
async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
@@ -541,7 +541,7 @@ impl TestServer {
let app_state = Self::build_app_state(&db_name).await;
let peer = Peer::new();
let mut router = Router::new();
- add_rpc_routes(&mut router, &app_state, &peer);
+ build_server(&mut router, &app_state, &peer);
Self {
peer,
router: Arc::new(router),
@@ -24,14 +24,12 @@ pub use settings::Settings;
use parking_lot::Mutex;
use postage::watch;
use std::sync::Arc;
-use zrpc::ForegroundRouter;
pub struct AppState {
pub settings_tx: Arc<Mutex<watch::Sender<Settings>>>,
pub settings: watch::Receiver<Settings>,
pub languages: Arc<language::LanguageRegistry>,
pub themes: Arc<settings::ThemeRegistry>,
- pub rpc_router: Arc<ForegroundRouter>,
pub rpc: rpc::Client,
pub fs: Arc<dyn fs::Fs>,
}
@@ -13,7 +13,7 @@ use zrpc::proto::EntityMessage;
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{
proto::{EnvelopedMessage, RequestMessage},
- ForegroundRouter, Peer, Receipt,
+ Peer, Receipt,
};
lazy_static! {
@@ -43,25 +43,6 @@ impl Client {
}
}
- pub fn on_message<H, M>(
- &self,
- router: &mut ForegroundRouter,
- handler: H,
- cx: &mut gpui::MutableAppContext,
- ) where
- H: 'static + Clone + for<'a> MessageHandler<'a, M>,
- M: proto::EnvelopedMessage,
- {
- let this = self.clone();
- let cx = cx.to_async();
- router.add_message_handler(move |message| {
- let this = this.clone();
- let mut cx = cx.clone();
- let handler = handler.clone();
- async move { handler.handle(message, &this, &mut cx).await }
- });
- }
-
pub fn subscribe_from_model<T, M, F>(
&self,
remote_id: u64,
@@ -90,11 +71,7 @@ impl Client {
})
}
- pub async fn log_in_and_connect(
- &self,
- router: Arc<ForegroundRouter>,
- cx: AsyncAppContext,
- ) -> surf::Result<()> {
+ pub async fn log_in_and_connect(&self, cx: AsyncAppContext) -> surf::Result<()> {
if self.state.read().await.connection_id.is_some() {
return Ok(());
}
@@ -111,13 +88,13 @@ impl Client {
.await
.context("websocket handshake")?;
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
- self.add_connection(stream, router, cx).await?;
+ self.add_connection(stream, cx).await?;
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::client_async(request, stream).await?;
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
- self.add_connection(stream, router, cx).await?;
+ self.add_connection(stream, cx).await?;
} else {
return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
};
@@ -125,12 +102,7 @@ impl Client {
Ok(())
}
- pub async fn add_connection<Conn>(
- &self,
- conn: Conn,
- router: Arc<ForegroundRouter>,
- cx: AsyncAppContext,
- ) -> surf::Result<()>
+ pub async fn add_connection<Conn>(&self, conn: Conn, cx: AsyncAppContext) -> surf::Result<()>
where
Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
@@ -138,8 +110,7 @@ impl Client {
+ Unpin
+ Send,
{
- let (connection_id, handle_io, handle_messages) =
- self.peer.add_connection(conn, router).await;
+ let (connection_id, handle_io, handle_messages) = self.peer.add_connection(conn).await;
cx.foreground().spawn(handle_messages).detach();
cx.background()
.spawn(async move {
@@ -15,7 +15,6 @@ use std::{
sync::Arc,
};
use tempdir::TempDir;
-use zrpc::ForegroundRouter;
#[cfg(feature = "test-support")]
pub use zrpc::test::Channel;
@@ -163,7 +162,6 @@ pub fn build_app_state(cx: &AppContext) -> Arc<AppState> {
settings,
themes,
languages: languages.clone(),
- rpc_router: Arc::new(ForegroundRouter::new()),
rpc: rpc::Client::new(languages),
fs: Arc::new(RealFs),
})
@@ -728,10 +728,9 @@ impl Workspace {
fn share_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
let rpc = self.rpc.clone();
let platform = cx.platform();
- let router = app_state.rpc_router.clone();
let task = cx.spawn(|this, mut cx| async move {
- rpc.log_in_and_connect(router, cx.clone()).await?;
+ rpc.log_in_and_connect(cx.clone()).await?;
let share_task = this.update(&mut cx, |this, cx| {
let worktree = this.worktrees.iter().next()?;
@@ -761,10 +760,9 @@ impl Workspace {
fn join_worktree(&mut self, app_state: &Arc<AppState>, cx: &mut ViewContext<Self>) {
let rpc = self.rpc.clone();
let languages = self.languages.clone();
- let router = app_state.rpc_router.clone();
let task = cx.spawn(|this, mut cx| async move {
- rpc.log_in_and_connect(router, cx.clone()).await?;
+ rpc.log_in_and_connect(cx.clone()).await?;
let worktree_url = cx
.platform()
@@ -1,6 +1,5 @@
pub mod auth;
mod peer;
-pub mod peer2;
pub mod proto;
#[cfg(any(test, feature = "test-support"))]
pub mod test;
@@ -2,17 +2,14 @@ use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{
- future::{self, BoxFuture, LocalBoxFuture},
- FutureExt, Stream, StreamExt,
-};
+use futures::{FutureExt, StreamExt};
use postage::{
- broadcast, mpsc,
+ mpsc,
prelude::{Sink as _, Stream as _},
};
use std::{
- any::{Any, TypeId},
- collections::{HashMap, HashSet},
+ any::Any,
+ collections::HashMap,
fmt,
future::Future,
marker::PhantomData,
@@ -25,17 +22,20 @@ use std::{
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);
+impl fmt::Display for ConnectionId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
-type MessageHandler = Box<
- dyn Send
- + Sync
- + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
->;
-
-type ForegroundMessageHandler =
- Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
+impl fmt::Display for PeerId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
pub struct Receipt<T> {
pub sender_id: ConnectionId,
@@ -43,6 +43,18 @@ pub struct Receipt<T> {
payload_type: PhantomData<T>,
}
+impl<T> Clone for Receipt<T> {
+ fn clone(&self) -> Self {
+ Self {
+ sender_id: self.sender_id,
+ message_id: self.message_id,
+ payload_type: PhantomData,
+ }
+ }
+}
+
+impl<T> Copy for Receipt<T> {}
+
pub struct TypedEnvelope<T> {
pub sender_id: ConnectionId,
pub original_sender_id: Option<PeerId>,
@@ -67,17 +79,9 @@ impl<T: RequestMessage> TypedEnvelope<T> {
}
}
-pub type Router = RouterInternal<MessageHandler>;
-pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
-pub struct RouterInternal<H> {
- message_handlers: Vec<H>,
- handler_types: HashSet<TypeId>,
-}
-
pub struct Peer {
connections: RwLock<HashMap<ConnectionId, Connection>>,
next_connection_id: AtomicU32,
- incoming_messages: broadcast::Sender<Arc<dyn Any + Send + Sync>>,
}
#[derive(Clone)]
@@ -92,22 +96,18 @@ impl Peer {
Arc::new(Self {
connections: Default::default(),
next_connection_id: Default::default(),
- incoming_messages: broadcast::channel(256).0,
})
}
- pub async fn add_connection<Conn, H, Fut>(
+ pub async fn add_connection<Conn>(
self: &Arc<Self>,
conn: Conn,
- router: Arc<RouterInternal<H>>,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
- impl Future<Output = ()>,
+ mpsc::Receiver<Box<dyn Any + Sync + Send>>,
)
where
- H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
- Fut: Future<Output = ()>,
Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
+ futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+ Send
@@ -118,7 +118,7 @@ impl Peer {
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
);
- let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
+ let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
let connection = Connection {
outgoing_tx,
@@ -128,6 +128,7 @@ impl Peer {
let mut writer = MessageStream::new(tx);
let mut reader = MessageStream::new(rx);
+ let response_channels = connection.response_channels.clone();
let handle_io = async move {
loop {
let read_message = reader.read_message().fuse();
@@ -136,57 +137,54 @@ impl Peer {
futures::select_biased! {
incoming = read_message => match incoming {
Ok(incoming) => {
- if incoming_tx.send(incoming).await.is_err() {
- return Ok(());
+ if let Some(responding_to) = incoming.responding_to {
+ let channel = response_channels.lock().await.remove(&responding_to);
+ if let Some(mut tx) = channel {
+ tx.send(incoming).await.ok();
+ } else {
+ log::warn!("received RPC response to unknown request {}", responding_to);
+ }
+ } else {
+ if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
+ if incoming_tx.send(envelope).await.is_err() {
+ response_channels.lock().await.clear();
+ return Ok(())
+ }
+ } else {
+ log::error!("unable to construct a typed envelope");
+ }
}
+
break;
}
Err(error) => {
+ response_channels.lock().await.clear();
Err(error).context("received invalid RPC message")?;
}
},
outgoing = outgoing_rx.recv().fuse() => match outgoing {
Some(outgoing) => {
if let Err(result) = writer.write_message(&outgoing).await {
+ response_channels.lock().await.clear();
Err(result).context("failed to write RPC message")?;
}
}
- None => return Ok(()),
+ None => {
+ response_channels.lock().await.clear();
+ return Ok(())
+ }
}
}
}
}
};
- let mut broadcast_incoming_messages = self.incoming_messages.clone();
- let response_channels = connection.response_channels.clone();
- let handle_messages = async move {
- while let Some(envelope) = incoming_rx.recv().await {
- if let Some(responding_to) = envelope.responding_to {
- let channel = response_channels.lock().await.remove(&responding_to);
- if let Some(mut tx) = channel {
- tx.send(envelope).await.ok();
- } else {
- log::warn!("received RPC response to unknown request {}", responding_to);
- }
- } else {
- router.handle(connection_id, envelope.clone()).await;
- if let Some(envelope) = proto::build_typed_envelope(connection_id, envelope) {
- broadcast_incoming_messages.send(Arc::from(envelope)).await.ok();
- } else {
- log::error!("unable to construct a typed envelope");
- }
- }
- }
- response_channels.lock().await.clear();
- };
-
self.connections
.write()
.await
.insert(connection_id, connection);
- (connection_id, handle_io, handle_messages)
+ (connection_id, handle_io, incoming_rx)
}
pub async fn disconnect(&self, connection_id: ConnectionId) {
@@ -197,12 +195,6 @@ impl Peer {
self.connections.write().await.clear();
}
- pub fn subscribe<T: EnvelopedMessage>(&self) -> impl Stream<Item = Arc<TypedEnvelope<T>>> {
- self.incoming_messages
- .subscribe()
- .filter_map(|envelope| future::ready(Arc::downcast(envelope).ok()))
- }
-
pub fn request<T: RequestMessage>(
self: &Arc<Self>,
receiver_id: ConnectionId,
@@ -325,142 +317,10 @@ impl Peer {
}
}
-impl<H, Fut> RouterInternal<H>
-where
- H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
- Fut: Future<Output = ()>,
-{
- pub fn new() -> Self {
- Self {
- message_handlers: Default::default(),
- handler_types: Default::default(),
- }
- }
-
- async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
- let mut envelope = Some(message);
- for handler in self.message_handlers.iter() {
- if let Some(future) = handler(&mut envelope, connection_id) {
- future.await;
- return;
- }
- }
- log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
- }
-}
-
-impl Router {
- pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
- where
- T: EnvelopedMessage,
- Fut: 'static + Send + Future<Output = Result<()>>,
- F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
- {
- if !self.handler_types.insert(TypeId::of::<T>()) {
- panic!("duplicate handler type");
- }
-
- self.message_handlers
- .push(Box::new(move |envelope, connection_id| {
- if envelope.as_ref().map_or(false, T::matches_envelope) {
- let envelope = Option::take(envelope).unwrap();
- let message_id = envelope.id;
- let future = handler(TypedEnvelope {
- sender_id: connection_id,
- original_sender_id: envelope.original_sender_id.map(PeerId),
- message_id,
- payload: T::from_envelope(envelope).unwrap(),
- });
- Some(
- async move {
- if let Err(error) = future.await {
- log::error!(
- "error handling message {} {}: {:?}",
- T::NAME,
- message_id,
- error
- );
- }
- }
- .boxed(),
- )
- } else {
- None
- }
- }));
- }
-}
-
-impl ForegroundRouter {
- pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
- where
- T: EnvelopedMessage,
- Fut: 'static + Future<Output = Result<()>>,
- F: 'static + Fn(TypedEnvelope<T>) -> Fut,
- {
- if !self.handler_types.insert(TypeId::of::<T>()) {
- panic!("duplicate handler type");
- }
-
- self.message_handlers
- .push(Box::new(move |envelope, connection_id| {
- if envelope.as_ref().map_or(false, T::matches_envelope) {
- let envelope = Option::take(envelope).unwrap();
- let message_id = envelope.id;
- let future = handler(TypedEnvelope {
- sender_id: connection_id,
- original_sender_id: envelope.original_sender_id.map(PeerId),
- message_id,
- payload: T::from_envelope(envelope).unwrap(),
- });
- Some(
- async move {
- if let Err(error) = future.await {
- log::error!(
- "error handling message {} {}: {:?}",
- T::NAME,
- message_id,
- error
- );
- }
- }
- .boxed_local(),
- )
- } else {
- None
- }
- }));
- }
-}
-
-impl<T> Clone for Receipt<T> {
- fn clone(&self) -> Self {
- Self {
- sender_id: self.sender_id,
- message_id: self.message_id,
- payload_type: PhantomData,
- }
- }
-}
-
-impl<T> Copy for Receipt<T> {}
-
-impl fmt::Display for ConnectionId {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- self.0.fmt(f)
- }
-}
-
-impl fmt::Display for PeerId {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- self.0.fmt(f)
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
- use crate::test;
+ use crate::{test, TypedEnvelope};
#[test]
fn test_request_response() {
@@ -470,139 +330,37 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
- let mut router = Router::new();
- router.add_message_handler({
- let server = server.clone();
- move |envelope: TypedEnvelope<proto::Auth>| {
- let server = server.clone();
- async move {
- let receipt = envelope.receipt();
- let message = envelope.payload;
- server
- .respond(
- receipt,
- match message.user_id {
- 1 => {
- assert_eq!(message.access_token, "access-token-1");
- proto::AuthResponse {
- credentials_valid: true,
- }
- }
- 2 => {
- assert_eq!(message.access_token, "access-token-2");
- proto::AuthResponse {
- credentials_valid: false,
- }
- }
- _ => {
- panic!("unexpected user id {}", message.user_id);
- }
- },
- )
- .await
- }
- }
- });
-
- router.add_message_handler({
- let server = server.clone();
- move |envelope: TypedEnvelope<proto::OpenBuffer>| {
- let server = server.clone();
- async move {
- let receipt = envelope.receipt();
- let message = envelope.payload;
- server
- .respond(
- receipt,
- match message.path.as_str() {
- "path/one" => {
- assert_eq!(message.worktree_id, 1);
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- id: 101,
- content: "path/one content".to_string(),
- history: vec![],
- selections: vec![],
- }),
- }
- }
- "path/two" => {
- assert_eq!(message.worktree_id, 2);
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- id: 102,
- content: "path/two content".to_string(),
- history: vec![],
- selections: vec![],
- }),
- }
- }
- _ => {
- panic!("unexpected path {}", message.path);
- }
- },
- )
- .await
- }
- }
- });
- let router = Arc::new(router);
-
let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
- let (client1_conn_id, io_task1, msg_task1) = client1
- .add_connection(client1_to_server_conn, router.clone())
- .await;
- let (_, io_task2, msg_task2) = server
- .add_connection(server_to_client_1_conn, router.clone())
- .await;
+ let (client1_conn_id, io_task1, _) =
+ client1.add_connection(client1_to_server_conn).await;
+ let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
- let (client2_conn_id, io_task3, msg_task3) = client2
- .add_connection(client2_to_server_conn, router.clone())
- .await;
- let (_, io_task4, msg_task4) = server
- .add_connection(server_to_client_2_conn, router.clone())
- .await;
+ let (client2_conn_id, io_task3, _) =
+ client2.add_connection(client2_to_server_conn).await;
+ let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
smol::spawn(io_task1).detach();
smol::spawn(io_task2).detach();
smol::spawn(io_task3).detach();
smol::spawn(io_task4).detach();
- smol::spawn(msg_task1).detach();
- smol::spawn(msg_task2).detach();
- smol::spawn(msg_task3).detach();
- smol::spawn(msg_task4).detach();
+ smol::spawn(handle_messages(incoming1, server.clone())).detach();
+ smol::spawn(handle_messages(incoming2, server.clone())).detach();
assert_eq!(
client1
- .request(
- client1_conn_id,
- proto::Auth {
- user_id: 1,
- access_token: "access-token-1".to_string(),
- },
- )
+ .request(client1_conn_id, proto::Ping { id: 1 },)
.await
.unwrap(),
- proto::AuthResponse {
- credentials_valid: true,
- }
+ proto::Pong { id: 1 }
);
assert_eq!(
client2
- .request(
- client2_conn_id,
- proto::Auth {
- user_id: 2,
- access_token: "access-token-2".to_string(),
- },
- )
+ .request(client2_conn_id, proto::Ping { id: 2 },)
.await
.unwrap(),
- proto::AuthResponse {
- credentials_valid: false,
- }
+ proto::Pong { id: 2 }
);
assert_eq!(
@@ -649,6 +407,62 @@ mod tests {
client1.disconnect(client1_conn_id).await;
client2.disconnect(client1_conn_id).await;
+
+ async fn handle_messages(
+ mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
+ peer: Arc<Peer>,
+ ) -> Result<()> {
+ while let Some(envelope) = messages.next().await {
+ if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
+ let receipt = envelope.receipt();
+ peer.respond(
+ receipt,
+ proto::Pong {
+ id: envelope.payload.id,
+ },
+ )
+ .await?
+ } else if let Some(envelope) =
+ envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
+ {
+ let message = &envelope.payload;
+ let receipt = envelope.receipt();
+ let response = match message.path.as_str() {
+ "path/one" => {
+ assert_eq!(message.worktree_id, 1);
+ proto::OpenBufferResponse {
+ buffer: Some(proto::Buffer {
+ id: 101,
+ content: "path/one content".to_string(),
+ history: vec![],
+ selections: vec![],
+ }),
+ }
+ }
+ "path/two" => {
+ assert_eq!(message.worktree_id, 2);
+ proto::OpenBufferResponse {
+ buffer: Some(proto::Buffer {
+ id: 102,
+ content: "path/two content".to_string(),
+ history: vec![],
+ selections: vec![],
+ }),
+ }
+ }
+ _ => {
+ panic!("unexpected path {}", message.path);
+ }
+ };
+
+ peer.respond(receipt, response).await?
+ } else {
+ panic!("unknown message type");
+ }
+ }
+
+ Ok(())
+ }
});
}
@@ -658,9 +472,8 @@ mod tests {
let (client_conn, mut server_conn) = test::Channel::bidirectional();
let client = Peer::new();
- let router = Arc::new(Router::new());
- let (connection_id, io_handler, message_handler) =
- client.add_connection(client_conn, router).await;
+ let (connection_id, io_handler, mut incoming) =
+ client.add_connection(client_conn).await;
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
smol::spawn(async move {
@@ -671,7 +484,7 @@ mod tests {
let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
smol::spawn(async move {
- message_handler.await;
+ incoming.next().await;
messages_ended_tx.send(()).await.unwrap();
})
.detach();
@@ -695,11 +508,10 @@ mod tests {
drop(server_conn);
let client = Peer::new();
- let router = Arc::new(Router::new());
- let (connection_id, io_handler, message_handler) =
- client.add_connection(client_conn, router).await;
+ let (connection_id, io_handler, mut incoming) =
+ client.add_connection(client_conn).await;
smol::spawn(io_handler).detach();
- smol::spawn(message_handler).detach();
+ smol::spawn(async move { incoming.next().await }).detach();
let err = client
.request(
@@ -1,470 +0,0 @@
-use crate::{
- proto::{self, EnvelopedMessage, MessageStream, RequestMessage},
- ConnectionId, PeerId, Receipt,
-};
-use anyhow::{anyhow, Context, Result};
-use async_lock::{Mutex, RwLock};
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{FutureExt, StreamExt};
-use postage::{
- mpsc,
- prelude::{Sink as _, Stream as _},
-};
-use std::{
- any::Any,
- collections::HashMap,
- future::Future,
- sync::{
- atomic::{self, AtomicU32},
- Arc,
- },
-};
-
-pub struct Peer {
- connections: RwLock<HashMap<ConnectionId, Connection>>,
- next_connection_id: AtomicU32,
-}
-
-#[derive(Clone)]
-struct Connection {
- outgoing_tx: mpsc::Sender<proto::Envelope>,
- next_message_id: Arc<AtomicU32>,
- response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
-}
-
-impl Peer {
- pub fn new() -> Arc<Self> {
- Arc::new(Self {
- connections: Default::default(),
- next_connection_id: Default::default(),
- })
- }
-
- pub async fn add_connection<Conn>(
- self: &Arc<Self>,
- conn: Conn,
- ) -> (
- ConnectionId,
- impl Future<Output = anyhow::Result<()>> + Send,
- mpsc::Receiver<Box<dyn Any + Sync + Send>>,
- )
- where
- Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
- + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
- + Send
- + Unpin,
- {
- let (tx, rx) = conn.split();
- let connection_id = ConnectionId(
- self.next_connection_id
- .fetch_add(1, atomic::Ordering::SeqCst),
- );
- let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
- let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
- let connection = Connection {
- outgoing_tx,
- next_message_id: Default::default(),
- response_channels: Default::default(),
- };
- let mut writer = MessageStream::new(tx);
- let mut reader = MessageStream::new(rx);
-
- let response_channels = connection.response_channels.clone();
- let handle_io = async move {
- loop {
- let read_message = reader.read_message().fuse();
- futures::pin_mut!(read_message);
- loop {
- futures::select_biased! {
- incoming = read_message => match incoming {
- Ok(incoming) => {
- if let Some(responding_to) = incoming.responding_to {
- let channel = response_channels.lock().await.remove(&responding_to);
- if let Some(mut tx) = channel {
- tx.send(incoming).await.ok();
- } else {
- log::warn!("received RPC response to unknown request {}", responding_to);
- }
- } else {
- if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
- if incoming_tx.send(envelope).await.is_err() {
- response_channels.lock().await.clear();
- return Ok(())
- }
- } else {
- log::error!("unable to construct a typed envelope");
- }
- }
-
- break;
- }
- Err(error) => {
- response_channels.lock().await.clear();
- Err(error).context("received invalid RPC message")?;
- }
- },
- outgoing = outgoing_rx.recv().fuse() => match outgoing {
- Some(outgoing) => {
- if let Err(result) = writer.write_message(&outgoing).await {
- response_channels.lock().await.clear();
- Err(result).context("failed to write RPC message")?;
- }
- }
- None => {
- response_channels.lock().await.clear();
- return Ok(())
- }
- }
- }
- }
- }
- };
-
- self.connections
- .write()
- .await
- .insert(connection_id, connection);
-
- (connection_id, handle_io, incoming_rx)
- }
-
- pub async fn disconnect(&self, connection_id: ConnectionId) {
- self.connections.write().await.remove(&connection_id);
- }
-
- pub async fn reset(&self) {
- self.connections.write().await.clear();
- }
-
- pub fn request<T: RequestMessage>(
- self: &Arc<Self>,
- receiver_id: ConnectionId,
- request: T,
- ) -> impl Future<Output = Result<T::Response>> {
- self.request_internal(None, receiver_id, request)
- }
-
- pub fn forward_request<T: RequestMessage>(
- self: &Arc<Self>,
- sender_id: ConnectionId,
- receiver_id: ConnectionId,
- request: T,
- ) -> impl Future<Output = Result<T::Response>> {
- self.request_internal(Some(sender_id), receiver_id, request)
- }
-
- pub fn request_internal<T: RequestMessage>(
- self: &Arc<Self>,
- original_sender_id: Option<ConnectionId>,
- receiver_id: ConnectionId,
- request: T,
- ) -> impl Future<Output = Result<T::Response>> {
- let this = self.clone();
- let (tx, mut rx) = mpsc::channel(1);
- async move {
- let mut connection = this.connection(receiver_id).await?;
- let message_id = connection
- .next_message_id
- .fetch_add(1, atomic::Ordering::SeqCst);
- connection
- .response_channels
- .lock()
- .await
- .insert(message_id, tx);
- connection
- .outgoing_tx
- .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
- .await
- .map_err(|_| anyhow!("connection was closed"))?;
- let response = rx
- .recv()
- .await
- .ok_or_else(|| anyhow!("connection was closed"))?;
- T::Response::from_envelope(response)
- .ok_or_else(|| anyhow!("received response of the wrong type"))
- }
- }
-
- pub fn send<T: EnvelopedMessage>(
- self: &Arc<Self>,
- receiver_id: ConnectionId,
- message: T,
- ) -> impl Future<Output = Result<()>> {
- let this = self.clone();
- async move {
- let mut connection = this.connection(receiver_id).await?;
- let message_id = connection
- .next_message_id
- .fetch_add(1, atomic::Ordering::SeqCst);
- connection
- .outgoing_tx
- .send(message.into_envelope(message_id, None, None))
- .await?;
- Ok(())
- }
- }
-
- pub fn forward_send<T: EnvelopedMessage>(
- self: &Arc<Self>,
- sender_id: ConnectionId,
- receiver_id: ConnectionId,
- message: T,
- ) -> impl Future<Output = Result<()>> {
- let this = self.clone();
- async move {
- let mut connection = this.connection(receiver_id).await?;
- let message_id = connection
- .next_message_id
- .fetch_add(1, atomic::Ordering::SeqCst);
- connection
- .outgoing_tx
- .send(message.into_envelope(message_id, None, Some(sender_id.0)))
- .await?;
- Ok(())
- }
- }
-
- pub fn respond<T: RequestMessage>(
- self: &Arc<Self>,
- receipt: Receipt<T>,
- response: T::Response,
- ) -> impl Future<Output = Result<()>> {
- let this = self.clone();
- async move {
- let mut connection = this.connection(receipt.sender_id).await?;
- let message_id = connection
- .next_message_id
- .fetch_add(1, atomic::Ordering::SeqCst);
- connection
- .outgoing_tx
- .send(response.into_envelope(message_id, Some(receipt.message_id), None))
- .await?;
- Ok(())
- }
- }
-
- fn connection(
- self: &Arc<Self>,
- connection_id: ConnectionId,
- ) -> impl Future<Output = Result<Connection>> {
- let this = self.clone();
- async move {
- let connections = this.connections.read().await;
- let connection = connections
- .get(&connection_id)
- .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
- Ok(connection.clone())
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::{test, TypedEnvelope};
-
- #[test]
- fn test_request_response() {
- smol::block_on(async move {
- // create 2 clients connected to 1 server
- let server = Peer::new();
- let client1 = Peer::new();
- let client2 = Peer::new();
-
- let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
- let (client1_conn_id, io_task1, _) =
- client1.add_connection(client1_to_server_conn).await;
- let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
-
- let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
- let (client2_conn_id, io_task3, _) =
- client2.add_connection(client2_to_server_conn).await;
- let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
-
- smol::spawn(io_task1).detach();
- smol::spawn(io_task2).detach();
- smol::spawn(io_task3).detach();
- smol::spawn(io_task4).detach();
- smol::spawn(handle_messages(incoming1, server.clone())).detach();
- smol::spawn(handle_messages(incoming2, server.clone())).detach();
-
- assert_eq!(
- client1
- .request(client1_conn_id, proto::Ping { id: 1 },)
- .await
- .unwrap(),
- proto::Pong { id: 1 }
- );
-
- assert_eq!(
- client2
- .request(client2_conn_id, proto::Ping { id: 2 },)
- .await
- .unwrap(),
- proto::Pong { id: 2 }
- );
-
- assert_eq!(
- client1
- .request(
- client1_conn_id,
- proto::OpenBuffer {
- worktree_id: 1,
- path: "path/one".to_string(),
- },
- )
- .await
- .unwrap(),
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- id: 101,
- content: "path/one content".to_string(),
- history: vec![],
- selections: vec![],
- }),
- }
- );
-
- assert_eq!(
- client2
- .request(
- client2_conn_id,
- proto::OpenBuffer {
- worktree_id: 2,
- path: "path/two".to_string(),
- },
- )
- .await
- .unwrap(),
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- id: 102,
- content: "path/two content".to_string(),
- history: vec![],
- selections: vec![],
- }),
- }
- );
-
- client1.disconnect(client1_conn_id).await;
- client2.disconnect(client1_conn_id).await;
-
- async fn handle_messages(
- mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
- peer: Arc<Peer>,
- ) -> Result<()> {
- while let Some(envelope) = messages.next().await {
- if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
- let receipt = envelope.receipt();
- peer.respond(
- receipt,
- proto::Pong {
- id: envelope.payload.id,
- },
- )
- .await?
- } else if let Some(envelope) =
- envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
- {
- let message = &envelope.payload;
- let receipt = envelope.receipt();
- let response = match message.path.as_str() {
- "path/one" => {
- assert_eq!(message.worktree_id, 1);
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- id: 101,
- content: "path/one content".to_string(),
- history: vec![],
- selections: vec![],
- }),
- }
- }
- "path/two" => {
- assert_eq!(message.worktree_id, 2);
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- id: 102,
- content: "path/two content".to_string(),
- history: vec![],
- selections: vec![],
- }),
- }
- }
- _ => {
- panic!("unexpected path {}", message.path);
- }
- };
-
- peer.respond(receipt, response).await?
- } else {
- panic!("unknown message type");
- }
- }
-
- Ok(())
- }
- });
- }
-
- #[test]
- fn test_disconnect() {
- smol::block_on(async move {
- let (client_conn, mut server_conn) = test::Channel::bidirectional();
-
- let client = Peer::new();
- let (connection_id, io_handler, mut incoming) =
- client.add_connection(client_conn).await;
-
- let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
- smol::spawn(async move {
- io_handler.await.ok();
- io_ended_tx.send(()).await.unwrap();
- })
- .detach();
-
- let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
- smol::spawn(async move {
- incoming.next().await;
- messages_ended_tx.send(()).await.unwrap();
- })
- .detach();
-
- client.disconnect(connection_id).await;
-
- io_ended_rx.recv().await;
- messages_ended_rx.recv().await;
- assert!(
- futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
- .await
- .is_err()
- );
- });
- }
-
- #[test]
- fn test_io_error() {
- smol::block_on(async move {
- let (client_conn, server_conn) = test::Channel::bidirectional();
- drop(server_conn);
-
- let client = Peer::new();
- let (connection_id, io_handler, mut incoming) =
- client.add_connection(client_conn).await;
- smol::spawn(io_handler).detach();
- smol::spawn(async move { incoming.next().await }).detach();
-
- let err = client
- .request(
- connection_id,
- proto::Auth {
- user_id: 42,
- access_token: "token".to_string(),
- },
- )
- .await
- .unwrap_err();
- assert_eq!(err.to_string(), "connection was closed");
- });
- }
-}