@@ -35,23 +35,26 @@ use zrpc::{
type ReplicaId = u16;
-type Handler = Box<
+type MessageHandler = Box<
dyn Send
+ Sync
- + Fn(&mut Option<Box<dyn Any + Send + Sync>>, Arc<Server>) -> Option<BoxFuture<'static, ()>>,
+ + Fn(
+ &mut Option<Box<dyn Any + Send + Sync>>,
+ Arc<Server>,
+ ) -> Option<BoxFuture<'static, tide::Result<()>>>,
>;
#[derive(Default)]
struct ServerBuilder {
- handlers: Vec<Handler>,
+ handlers: Vec<MessageHandler>,
handler_types: HashSet<TypeId>,
}
impl ServerBuilder {
- pub fn on_message<F, Fut, M>(&mut self, handler: F) -> &mut Self
+ pub fn on_message<F, Fut, M>(mut self, handler: F) -> Self
where
F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
- Fut: 'static + Send + Future<Output = ()>,
+ Fut: 'static + Send + Future<Output = tide::Result<()>>,
M: EnvelopedMessage,
{
if self.handler_types.insert(TypeId::of::<M>()) {
@@ -87,7 +90,7 @@ impl ServerBuilder {
pub struct Server {
rpc: Arc<Peer>,
state: Arc<AppState>,
- handlers: Vec<Handler>,
+ handlers: Vec<MessageHandler>,
}
impl Server {
@@ -119,10 +122,16 @@ impl Server {
futures::select_biased! {
message = next_message => {
if let Some(message) = message {
+ let start_time = Instant::now();
+ log::info!("RPC message received");
let mut message = Some(message);
for handler in &this.handlers {
if let Some(future) = (handler)(&mut message, this.clone()) {
- future.await;
+ if let Err(err) = future.await {
+ log::error!("error handling message: {:?}", err);
+ } else {
+ log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
+ }
break;
}
}
@@ -336,26 +345,24 @@ impl State {
pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> {
ServerBuilder::default()
- // .on_message(share_worktree)
- // .on_message(join_worktree)
- // .on_message(update_worktree)
- // .on_message(close_worktree)
- // .on_message(open_buffer)
- // .on_message(close_buffer)
- // .on_message(update_buffer)
- // .on_message(buffer_saved)
- // .on_message(save_buffer)
- // .on_message(get_channels)
- // .on_message(get_users)
- // .on_message(join_channel)
- // .on_message(send_channel_message)
+ .on_message(share_worktree)
+ .on_message(join_worktree)
+ .on_message(update_worktree)
+ .on_message(close_worktree)
+ .on_message(open_buffer)
+ .on_message(close_buffer)
+ .on_message(update_buffer)
+ .on_message(buffer_saved)
+ .on_message(save_buffer)
+ .on_message(get_channels)
+ .on_message(get_users)
+ .on_message(join_channel)
+ .on_message(send_channel_message)
.build(rpc, state)
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let server = build_server(app.state(), rpc);
-
- let rpc = rpc.clone();
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
let user_id = request.ext::<UserId>().copied();
let server = server.clone();
@@ -399,11 +406,10 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
}
async fn share_worktree(
- mut request: TypedEnvelope<proto::ShareWorktree>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ mut request: Box<TypedEnvelope<proto::ShareWorktree>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
- let mut state = state.rpc.write().await;
+ let mut state = server.state.rpc.write().await;
let worktree_id = state.next_worktree_id;
state.next_worktree_id += 1;
let access_token = random_token();
@@ -428,26 +434,27 @@ async fn share_worktree(
},
);
- rpc.respond(
- request.receipt(),
- proto::ShareWorktreeResponse {
- worktree_id,
- access_token,
- },
- )
- .await?;
+ server
+ .rpc
+ .respond(
+ request.receipt(),
+ proto::ShareWorktreeResponse {
+ worktree_id,
+ access_token,
+ },
+ )
+ .await?;
Ok(())
}
async fn join_worktree(
- request: TypedEnvelope<proto::OpenWorktree>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::OpenWorktree>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let access_token = &request.payload.access_token;
- let mut state = state.rpc.write().await;
+ let mut state = server.state.rpc.write().await;
if let Some((peer_replica_id, worktree)) =
state.join_worktree(request.sender_id, worktree_id, access_token)
{
@@ -468,7 +475,7 @@ async fn join_worktree(
}
broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
- rpc.send(
+ server.rpc.send(
conn_id,
proto::AddPeer {
worktree_id,
@@ -480,42 +487,45 @@ async fn join_worktree(
)
})
.await?;
- rpc.respond(
- request.receipt(),
- proto::OpenWorktreeResponse {
- worktree_id,
- worktree: Some(proto::Worktree {
- root_name: worktree.root_name.clone(),
- entries: worktree.entries.values().cloned().collect(),
- }),
- replica_id: peer_replica_id as u32,
- peers,
- },
- )
- .await?;
+ server
+ .rpc
+ .respond(
+ request.receipt(),
+ proto::OpenWorktreeResponse {
+ worktree_id,
+ worktree: Some(proto::Worktree {
+ root_name: worktree.root_name.clone(),
+ entries: worktree.entries.values().cloned().collect(),
+ }),
+ replica_id: peer_replica_id as u32,
+ peers,
+ },
+ )
+ .await?;
} else {
- rpc.respond(
- request.receipt(),
- proto::OpenWorktreeResponse {
- worktree_id,
- worktree: None,
- replica_id: 0,
- peers: Vec::new(),
- },
- )
- .await?;
+ server
+ .rpc
+ .respond(
+ request.receipt(),
+ proto::OpenWorktreeResponse {
+ worktree_id,
+ worktree: None,
+ replica_id: 0,
+ peers: Vec::new(),
+ },
+ )
+ .await?;
}
Ok(())
}
async fn update_worktree(
- request: TypedEnvelope<proto::UpdateWorktree>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::UpdateWorktree>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
{
- let mut state = state.rpc.write().await;
+ let mut state = server.state.rpc.write().await;
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
for entry_id in &request.payload.removed_entries {
worktree.entries.remove(&entry_id);
@@ -526,18 +536,17 @@ async fn update_worktree(
}
}
- broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await?;
+ broadcast_in_worktree(request.payload.worktree_id, &request, &server).await?;
Ok(())
}
async fn close_worktree(
- request: TypedEnvelope<proto::CloseWorktree>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::CloseWorktree>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
let connection_ids;
{
- let mut state = state.rpc.write().await;
+ let mut state = server.state.rpc.write().await;
let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
connection_ids = worktree.connection_ids();
if worktree.host_connection_id == Some(request.sender_id) {
@@ -548,7 +557,7 @@ async fn close_worktree(
}
broadcast(request.sender_id, connection_ids, |conn_id| {
- rpc.send(
+ server.rpc.send(
conn_id,
proto::RemovePeer {
worktree_id: request.payload.worktree_id,
@@ -562,53 +571,55 @@ async fn close_worktree(
}
async fn open_buffer(
- request: TypedEnvelope<proto::OpenBuffer>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::OpenBuffer>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
let receipt = request.receipt();
let worktree_id = request.payload.worktree_id;
- let host_connection_id = state
+ let host_connection_id = server
+ .state
.rpc
.read()
.await
.read_worktree(worktree_id, request.sender_id)?
.host_connection_id()?;
- let response = rpc
+ let response = server
+ .rpc
.forward_request(request.sender_id, host_connection_id, request.payload)
.await?;
- rpc.respond(receipt, response).await?;
+ server.rpc.respond(receipt, response).await?;
Ok(())
}
async fn close_buffer(
- request: TypedEnvelope<proto::CloseBuffer>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::CloseBuffer>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
- let host_connection_id = state
+ let host_connection_id = server
+ .state
.rpc
.read()
.await
.read_worktree(request.payload.worktree_id, request.sender_id)?
.host_connection_id()?;
- rpc.forward_send(request.sender_id, host_connection_id, request.payload)
+ server
+ .rpc
+ .forward_send(request.sender_id, host_connection_id, request.payload)
.await?;
Ok(())
}
async fn save_buffer(
- request: TypedEnvelope<proto::SaveBuffer>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::SaveBuffer>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
let host;
let guests;
{
- let state = state.rpc.read().await;
+ let state = server.state.rpc.read().await;
let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
host = worktree.host_connection_id()?;
guests = worktree
@@ -620,17 +631,19 @@ async fn save_buffer(
let sender = request.sender_id;
let receipt = request.receipt();
- let response = rpc
+ let response = server
+ .rpc
.forward_request(sender, host, request.payload.clone())
.await?;
broadcast(host, guests, |conn_id| {
let response = response.clone();
+ let server = &server;
async move {
if conn_id == sender {
- rpc.respond(receipt, response).await
+ server.rpc.respond(receipt, response).await
} else {
- rpc.forward_send(host, conn_id, response).await
+ server.rpc.forward_send(host, conn_id, response).await
}
}
})
@@ -640,61 +653,62 @@ async fn save_buffer(
}
async fn update_buffer(
- request: TypedEnvelope<proto::UpdateBuffer>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::UpdateBuffer>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
- broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
+ broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
}
async fn buffer_saved(
- request: TypedEnvelope<proto::BufferSaved>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::BufferSaved>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
- broadcast_in_worktree(request.payload.worktree_id, request, rpc, state).await
+ broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
}
async fn get_channels(
- request: TypedEnvelope<proto::GetChannels>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::GetChannels>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
- let user_id = state
+ let user_id = server
+ .state
.rpc
.read()
.await
.user_id_for_connection(request.sender_id)?;
- let channels = state.db.get_channels_for_user(user_id).await?;
- rpc.respond(
- request.receipt(),
- proto::GetChannelsResponse {
- channels: channels
- .into_iter()
- .map(|chan| proto::Channel {
- id: chan.id.to_proto(),
- name: chan.name,
- })
- .collect(),
- },
- )
- .await?;
+ let channels = server.state.db.get_channels_for_user(user_id).await?;
+ server
+ .rpc
+ .respond(
+ request.receipt(),
+ proto::GetChannelsResponse {
+ channels: channels
+ .into_iter()
+ .map(|chan| proto::Channel {
+ id: chan.id.to_proto(),
+ name: chan.name,
+ })
+ .collect(),
+ },
+ )
+ .await?;
Ok(())
}
async fn get_users(
- request: TypedEnvelope<proto::GetUsers>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::GetUsers>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
- let user_id = state
+ let user_id = server
+ .state
.rpc
.read()
.await
.user_id_for_connection(request.sender_id)?;
let receipt = request.receipt();
let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
- let users = state
+ let users = server
+ .state
.db
.get_users_by_ids(user_id, user_ids)
.await?
@@ -705,23 +719,26 @@ async fn get_users(
avatar_url: String::new(),
})
.collect();
- rpc.respond(receipt, proto::GetUsersResponse { users })
+ server
+ .rpc
+ .respond(receipt, proto::GetUsersResponse { users })
.await?;
Ok(())
}
async fn join_channel(
- request: TypedEnvelope<proto::JoinChannel>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::JoinChannel>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
- let user_id = state
+ let user_id = server
+ .state
.rpc
.read()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
- if !state
+ if !server
+ .state
.db
.can_user_access_channel(user_id, channel_id)
.await?
@@ -729,12 +746,14 @@ async fn join_channel(
Err(anyhow!("access denied"))?;
}
- state
+ server
+ .state
.rpc
.write()
.await
.join_channel(request.sender_id, channel_id);
- let messages = state
+ let messages = server
+ .state
.db
.get_recent_channel_messages(channel_id, 50)
.await?
@@ -746,21 +765,22 @@ async fn join_channel(
sender_id: msg.sender_id.to_proto(),
})
.collect();
- rpc.respond(request.receipt(), proto::JoinChannelResponse { messages })
+ server
+ .rpc
+ .respond(request.receipt(), proto::JoinChannelResponse { messages })
.await?;
Ok(())
}
async fn send_channel_message(
- request: TypedEnvelope<proto::SendChannelMessage>,
- peer: &Arc<Peer>,
- app: &Arc<AppState>,
+ request: Box<TypedEnvelope<proto::SendChannelMessage>>,
+ server: Arc<Server>,
) -> tide::Result<()> {
let channel_id = ChannelId::from_proto(request.payload.channel_id);
let user_id;
let connection_ids;
{
- let state = app.rpc.read().await;
+ let state = server.state.rpc.read().await;
user_id = state.user_id_for_connection(request.sender_id)?;
if let Some(channel) = state.channels.get(&channel_id) {
connection_ids = channel.connection_ids();
@@ -770,7 +790,8 @@ async fn send_channel_message(
}
let timestamp = OffsetDateTime::now_utc();
- let message_id = app
+ let message_id = server
+ .state
.db
.create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
.await?;
@@ -784,7 +805,7 @@ async fn send_channel_message(
}),
};
broadcast(request.sender_id, connection_ids, |conn_id| {
- peer.send(conn_id, message.clone())
+ server.rpc.send(conn_id, message.clone())
})
.await?;
@@ -793,11 +814,11 @@ async fn send_channel_message(
async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
worktree_id: u64,
- request: TypedEnvelope<T>,
- rpc: &Arc<Peer>,
- state: &Arc<AppState>,
+ request: &TypedEnvelope<T>,
+ server: &Arc<Server>,
) -> tide::Result<()> {
- let connection_ids = state
+ let connection_ids = server
+ .state
.rpc
.read()
.await
@@ -805,7 +826,9 @@ async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
.connection_ids();
broadcast(request.sender_id, connection_ids, |conn_id| {
- rpc.forward_send(request.sender_id, conn_id, request.payload.clone())
+ server
+ .rpc
+ .forward_send(request.sender_id, conn_id, request.payload.clone())
})
.await?;