@@ -1,10 +1,10 @@
use super::{
- auth::{self, PeerExt as _},
+ auth,
db::{ChannelId, UserId},
AppState,
};
use anyhow::anyhow;
-use async_std::task;
+use async_std::{sync::RwLock, task};
use async_tungstenite::{
tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
WebSocketStream,
@@ -13,7 +13,7 @@ use futures::{future::BoxFuture, FutureExt};
use postage::prelude::Stream as _;
use sha1::{Digest as _, Sha1};
use std::{
- any::{Any, TypeId},
+ any::TypeId,
collections::{HashMap, HashSet},
future::Future,
mem,
@@ -38,51 +38,90 @@ type ReplicaId = u16;
type MessageHandler = Box<
dyn Send
+ Sync
- + Fn(Box<dyn AnyTypedEnvelope>, Arc<Server>) -> BoxFuture<'static, tide::Result<()>>,
+ + Fn(Arc<Server>, Box<dyn AnyTypedEnvelope>) -> BoxFuture<'static, tide::Result<()>>,
>;
-#[derive(Default)]
-struct ServerBuilder {
+pub struct Server {
+ peer: Arc<Peer>,
+ state: RwLock<ServerState>,
+ app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>,
}
-impl ServerBuilder {
- pub fn on_message<F, Fut, M>(mut self, handler: F) -> Self
+#[derive(Default)]
+struct ServerState {
+ connections: HashMap<ConnectionId, Connection>,
+ pub worktrees: HashMap<u64, Worktree>,
+ channels: HashMap<ChannelId, Channel>,
+ next_worktree_id: u64,
+}
+
+struct Connection {
+ user_id: UserId,
+ worktrees: HashSet<u64>,
+ channels: HashSet<ChannelId>,
+}
+
+struct Worktree {
+ host_connection_id: Option<ConnectionId>,
+ guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
+ active_replica_ids: HashSet<ReplicaId>,
+ access_token: String,
+ root_name: String,
+ entries: HashMap<u64, proto::Entry>,
+}
+
+#[derive(Default)]
+struct Channel {
+ connection_ids: HashSet<ConnectionId>,
+}
+
+impl Server {
+ pub fn new(app_state: Arc<AppState>, peer: Arc<Peer>) -> Arc<Self> {
+ let mut server = Server {
+ peer,
+ app_state,
+ state: Default::default(),
+ handlers: Default::default(),
+ };
+
+ server
+ .add_handler(Server::share_worktree)
+ .add_handler(Server::join_worktree)
+ .add_handler(Server::update_worktree)
+ .add_handler(Server::close_worktree)
+ .add_handler(Server::open_buffer)
+ .add_handler(Server::close_buffer)
+ .add_handler(Server::update_buffer)
+ .add_handler(Server::buffer_saved)
+ .add_handler(Server::save_buffer)
+ .add_handler(Server::get_channels)
+ .add_handler(Server::get_users)
+ .add_handler(Server::join_channel)
+ .add_handler(Server::send_channel_message);
+
+ Arc::new(server)
+ }
+
+ fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
where
- F: 'static + Send + Sync + Fn(Box<TypedEnvelope<M>>, Arc<Server>) -> Fut,
+ F: 'static + Send + Sync + Fn(Arc<Self>, TypedEnvelope<M>) -> Fut,
Fut: 'static + Send + Future<Output = tide::Result<()>>,
M: EnvelopedMessage,
{
let prev_handler = self.handlers.insert(
TypeId::of::<M>(),
- Box::new(move |envelope, server| {
+ Box::new(move |server, envelope| {
let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
- (handler)(envelope, server).boxed()
+ (handler)(server, *envelope).boxed()
}),
);
if prev_handler.is_some() {
panic!("registered a handler for the same message twice");
}
-
self
}
- pub fn build(self, rpc: &Arc<Peer>, state: &Arc<AppState>) -> Arc<Server> {
- Arc::new(Server {
- rpc: rpc.clone(),
- state: state.clone(),
- handlers: self.handlers,
- })
- }
-}
-
-pub struct Server {
- rpc: Arc<Peer>,
- state: Arc<AppState>,
- handlers: HashMap<TypeId, MessageHandler>,
-}
-
-impl Server {
pub fn handle_connection<Conn>(
self: &Arc<Self>,
connection: Conn,
@@ -99,12 +138,8 @@ impl Server {
let this = self.clone();
async move {
let (connection_id, handle_io, mut incoming_rx) =
- this.rpc.add_connection(connection).await;
- this.state
- .rpc
- .write()
- .await
- .add_connection(connection_id, user_id);
+ this.peer.add_connection(connection).await;
+ this.add_connection(connection_id, user_id).await;
let handle_io = handle_io.fuse();
futures::pin_mut!(handle_io);
@@ -117,7 +152,7 @@ impl Server {
let start_time = Instant::now();
log::info!("RPC message received: {}", message.payload_type_name());
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
- if let Err(err) = (handler)(message, this.clone()).await {
+ if let Err(err) = (handler)(this.clone(), message).await {
log::error!("error handling message: {:?}", err);
} else {
log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
@@ -139,67 +174,36 @@ impl Server {
}
}
- if let Err(err) = this.rpc.sign_out(connection_id, &this.state).await {
+ if let Err(err) = this.sign_out(connection_id).await {
log::error!("error signing out connection {:?} - {:?}", addr, err);
}
}
}
-}
-
-#[derive(Default)]
-pub struct State {
- connections: HashMap<ConnectionId, Connection>,
- pub worktrees: HashMap<u64, Worktree>,
- channels: HashMap<ChannelId, Channel>,
- next_worktree_id: u64,
-}
-
-struct Connection {
- user_id: UserId,
- worktrees: HashSet<u64>,
- channels: HashSet<ChannelId>,
-}
-
-pub struct Worktree {
- host_connection_id: Option<ConnectionId>,
- guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
- active_replica_ids: HashSet<ReplicaId>,
- access_token: String,
- root_name: String,
- entries: HashMap<u64, proto::Entry>,
-}
-
-#[derive(Default)]
-struct Channel {
- connection_ids: HashSet<ConnectionId>,
-}
-
-impl Worktree {
- pub fn connection_ids(&self) -> Vec<ConnectionId> {
- self.guest_connection_ids
- .keys()
- .copied()
- .chain(self.host_connection_id)
- .collect()
- }
-
- fn host_connection_id(&self) -> tide::Result<ConnectionId> {
- Ok(self
- .host_connection_id
- .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
- }
-}
-impl Channel {
- fn connection_ids(&self) -> Vec<ConnectionId> {
- self.connection_ids.iter().copied().collect()
+ async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
+ self.peer.disconnect(connection_id).await;
+ let worktree_ids = self.remove_connection(connection_id).await;
+ for worktree_id in worktree_ids {
+ let state = self.state.read().await;
+ if let Some(worktree) = state.worktrees.get(&worktree_id) {
+ broadcast(connection_id, worktree.connection_ids(), |conn_id| {
+ self.peer.send(
+ conn_id,
+ proto::RemovePeer {
+ worktree_id,
+ peer_id: connection_id.0,
+ },
+ )
+ })
+ .await?;
+ }
+ }
+ Ok(())
}
-}
-impl State {
// Add a new connection associated with a given user.
- pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
- self.connections.insert(
+ async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
+ self.state.write().await.connections.insert(
connection_id,
Connection {
user_id,
@@ -210,16 +214,17 @@ impl State {
}
// Remove the given connection and its association with any worktrees.
- pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Vec<u64> {
+ async fn remove_connection(&self, connection_id: ConnectionId) -> Vec<u64> {
let mut worktree_ids = Vec::new();
- if let Some(connection) = self.connections.remove(&connection_id) {
+ let mut state = self.state.write().await;
+ if let Some(connection) = state.connections.remove(&connection_id) {
for channel_id in connection.channels {
- if let Some(channel) = self.channels.get_mut(&channel_id) {
+ if let Some(channel) = state.channels.get_mut(&channel_id) {
channel.connection_ids.remove(&connection_id);
}
}
for worktree_id in connection.worktrees {
- if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
+ if let Some(worktree) = state.worktrees.get_mut(&worktree_id) {
if worktree.host_connection_id == Some(connection_id) {
worktree_ids.push(worktree_id);
} else if let Some(replica_id) =
@@ -234,6 +239,444 @@ impl State {
worktree_ids
}
+ async fn share_worktree(
+ self: Arc<Server>,
+ mut request: TypedEnvelope<proto::ShareWorktree>,
+ ) -> tide::Result<()> {
+ let mut state = self.state.write().await;
+ let worktree_id = state.next_worktree_id;
+ state.next_worktree_id += 1;
+ let access_token = random_token();
+ let worktree = request
+ .payload
+ .worktree
+ .as_mut()
+ .ok_or_else(|| anyhow!("missing worktree"))?;
+ let entries = mem::take(&mut worktree.entries)
+ .into_iter()
+ .map(|entry| (entry.id, entry))
+ .collect();
+ state.worktrees.insert(
+ worktree_id,
+ Worktree {
+ host_connection_id: Some(request.sender_id),
+ guest_connection_ids: Default::default(),
+ active_replica_ids: Default::default(),
+ access_token: access_token.clone(),
+ root_name: mem::take(&mut worktree.root_name),
+ entries,
+ },
+ );
+
+ self.peer
+ .respond(
+ request.receipt(),
+ proto::ShareWorktreeResponse {
+ worktree_id,
+ access_token,
+ },
+ )
+ .await?;
+ Ok(())
+ }
+
+ async fn join_worktree(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::OpenWorktree>,
+ ) -> tide::Result<()> {
+ let worktree_id = request.payload.worktree_id;
+ let access_token = &request.payload.access_token;
+
+ let mut state = self.state.write().await;
+ if let Some((peer_replica_id, worktree)) =
+ state.join_worktree(request.sender_id, worktree_id, access_token)
+ {
+ let mut peers = Vec::new();
+ if let Some(host_connection_id) = worktree.host_connection_id {
+ peers.push(proto::Peer {
+ peer_id: host_connection_id.0,
+ replica_id: 0,
+ });
+ }
+ for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
+ if *peer_conn_id != request.sender_id {
+ peers.push(proto::Peer {
+ peer_id: peer_conn_id.0,
+ replica_id: *peer_replica_id as u32,
+ });
+ }
+ }
+
+ broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
+ self.peer.send(
+ conn_id,
+ proto::AddPeer {
+ worktree_id,
+ peer: Some(proto::Peer {
+ peer_id: request.sender_id.0,
+ replica_id: peer_replica_id as u32,
+ }),
+ },
+ )
+ })
+ .await?;
+ self.peer
+ .respond(
+ request.receipt(),
+ proto::OpenWorktreeResponse {
+ worktree_id,
+ worktree: Some(proto::Worktree {
+ root_name: worktree.root_name.clone(),
+ entries: worktree.entries.values().cloned().collect(),
+ }),
+ replica_id: peer_replica_id as u32,
+ peers,
+ },
+ )
+ .await?;
+ } else {
+ self.peer
+ .respond(
+ request.receipt(),
+ proto::OpenWorktreeResponse {
+ worktree_id,
+ worktree: None,
+ replica_id: 0,
+ peers: Vec::new(),
+ },
+ )
+ .await?;
+ }
+
+ Ok(())
+ }
+
+ async fn update_worktree(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::UpdateWorktree>,
+ ) -> tide::Result<()> {
+ {
+ let mut state = self.state.write().await;
+ let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
+ for entry_id in &request.payload.removed_entries {
+ worktree.entries.remove(&entry_id);
+ }
+
+ for entry in &request.payload.updated_entries {
+ worktree.entries.insert(entry.id, entry.clone());
+ }
+ }
+
+ self.broadcast_in_worktree(request.payload.worktree_id, &request)
+ .await?;
+ Ok(())
+ }
+
+ async fn close_worktree(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::CloseWorktree>,
+ ) -> tide::Result<()> {
+ let connection_ids;
+ {
+ let mut state = self.state.write().await;
+ let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
+ connection_ids = worktree.connection_ids();
+ if worktree.host_connection_id == Some(request.sender_id) {
+ worktree.host_connection_id = None;
+ } else if let Some(replica_id) =
+ worktree.guest_connection_ids.remove(&request.sender_id)
+ {
+ worktree.active_replica_ids.remove(&replica_id);
+ }
+ }
+
+ broadcast(request.sender_id, connection_ids, |conn_id| {
+ self.peer.send(
+ conn_id,
+ proto::RemovePeer {
+ worktree_id: request.payload.worktree_id,
+ peer_id: request.sender_id.0,
+ },
+ )
+ })
+ .await?;
+
+ Ok(())
+ }
+
+ async fn open_buffer(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::OpenBuffer>,
+ ) -> tide::Result<()> {
+ let receipt = request.receipt();
+ let worktree_id = request.payload.worktree_id;
+ let host_connection_id = self
+ .state
+ .read()
+ .await
+ .read_worktree(worktree_id, request.sender_id)?
+ .host_connection_id()?;
+
+ let response = self
+ .peer
+ .forward_request(request.sender_id, host_connection_id, request.payload)
+ .await?;
+ self.peer.respond(receipt, response).await?;
+ Ok(())
+ }
+
+ async fn close_buffer(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::CloseBuffer>,
+ ) -> tide::Result<()> {
+ let host_connection_id = self
+ .state
+ .read()
+ .await
+ .read_worktree(request.payload.worktree_id, request.sender_id)?
+ .host_connection_id()?;
+
+ self.peer
+ .forward_send(request.sender_id, host_connection_id, request.payload)
+ .await?;
+
+ Ok(())
+ }
+
+ async fn save_buffer(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::SaveBuffer>,
+ ) -> tide::Result<()> {
+ let host;
+ let guests;
+ {
+ let state = self.state.read().await;
+ let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
+ host = worktree.host_connection_id()?;
+ guests = worktree
+ .guest_connection_ids
+ .keys()
+ .copied()
+ .collect::<Vec<_>>();
+ }
+
+ let sender = request.sender_id;
+ let receipt = request.receipt();
+ let response = self
+ .peer
+ .forward_request(sender, host, request.payload.clone())
+ .await?;
+
+ broadcast(host, guests, |conn_id| {
+ let response = response.clone();
+ let peer = &self.peer;
+ async move {
+ if conn_id == sender {
+ peer.respond(receipt, response).await
+ } else {
+ peer.forward_send(host, conn_id, response).await
+ }
+ }
+ })
+ .await?;
+
+ Ok(())
+ }
+
+ async fn update_buffer(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::UpdateBuffer>,
+ ) -> tide::Result<()> {
+ self.broadcast_in_worktree(request.payload.worktree_id, &request)
+ .await
+ }
+
+ async fn buffer_saved(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::BufferSaved>,
+ ) -> tide::Result<()> {
+ self.broadcast_in_worktree(request.payload.worktree_id, &request)
+ .await
+ }
+
+ async fn get_channels(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::GetChannels>,
+ ) -> tide::Result<()> {
+ let user_id = self
+ .state
+ .read()
+ .await
+ .user_id_for_connection(request.sender_id)?;
+ let channels = self.app_state.db.get_channels_for_user(user_id).await?;
+ self.peer
+ .respond(
+ request.receipt(),
+ proto::GetChannelsResponse {
+ channels: channels
+ .into_iter()
+ .map(|chan| proto::Channel {
+ id: chan.id.to_proto(),
+ name: chan.name,
+ })
+ .collect(),
+ },
+ )
+ .await?;
+ Ok(())
+ }
+
+ async fn get_users(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::GetUsers>,
+ ) -> tide::Result<()> {
+ let user_id = self
+ .state
+ .read()
+ .await
+ .user_id_for_connection(request.sender_id)?;
+ let receipt = request.receipt();
+ let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
+ let users = self
+ .app_state
+ .db
+ .get_users_by_ids(user_id, user_ids)
+ .await?
+ .into_iter()
+ .map(|user| proto::User {
+ id: user.id.to_proto(),
+ github_login: user.github_login,
+ avatar_url: String::new(),
+ })
+ .collect();
+ self.peer
+ .respond(receipt, proto::GetUsersResponse { users })
+ .await?;
+ Ok(())
+ }
+
+ async fn join_channel(
+ self: Arc<Self>,
+ request: TypedEnvelope<proto::JoinChannel>,
+ ) -> tide::Result<()> {
+ let user_id = self
+ .state
+ .read()
+ .await
+ .user_id_for_connection(request.sender_id)?;
+ let channel_id = ChannelId::from_proto(request.payload.channel_id);
+ if !self
+ .app_state
+ .db
+ .can_user_access_channel(user_id, channel_id)
+ .await?
+ {
+ Err(anyhow!("access denied"))?;
+ }
+
+ self.state
+ .write()
+ .await
+ .join_channel(request.sender_id, channel_id);
+ let messages = self
+ .app_state
+ .db
+ .get_recent_channel_messages(channel_id, 50)
+ .await?
+ .into_iter()
+ .map(|msg| proto::ChannelMessage {
+ id: msg.id.to_proto(),
+ body: msg.body,
+ timestamp: msg.sent_at.unix_timestamp() as u64,
+ sender_id: msg.sender_id.to_proto(),
+ })
+ .collect();
+ self.peer
+ .respond(request.receipt(), proto::JoinChannelResponse { messages })
+ .await?;
+ Ok(())
+ }
+
+ async fn send_channel_message(
+ self: Arc<Self>,
+ request: TypedEnvelope<proto::SendChannelMessage>,
+ ) -> tide::Result<()> {
+ let channel_id = ChannelId::from_proto(request.payload.channel_id);
+ let user_id;
+ let connection_ids;
+ {
+ let state = self.state.read().await;
+ user_id = state.user_id_for_connection(request.sender_id)?;
+ if let Some(channel) = state.channels.get(&channel_id) {
+ connection_ids = channel.connection_ids();
+ } else {
+ return Ok(());
+ }
+ }
+
+ let timestamp = OffsetDateTime::now_utc();
+ let message_id = self
+ .app_state
+ .db
+ .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
+ .await?;
+ let message = proto::ChannelMessageSent {
+ channel_id: channel_id.to_proto(),
+ message: Some(proto::ChannelMessage {
+ sender_id: user_id.to_proto(),
+ id: message_id.to_proto(),
+ body: request.payload.body,
+ timestamp: timestamp.unix_timestamp() as u64,
+ }),
+ };
+ broadcast(request.sender_id, connection_ids, |conn_id| {
+ self.peer.send(conn_id, message.clone())
+ })
+ .await?;
+
+ Ok(())
+ }
+
+ async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
+ &self,
+ worktree_id: u64,
+ message: &TypedEnvelope<T>,
+ ) -> tide::Result<()> {
+ let connection_ids = self
+ .state
+ .read()
+ .await
+ .read_worktree(worktree_id, message.sender_id)?
+ .connection_ids();
+
+ broadcast(message.sender_id, connection_ids, |conn_id| {
+ self.peer
+ .forward_send(message.sender_id, conn_id, message.payload.clone())
+ })
+ .await?;
+
+ Ok(())
+ }
+}
+
+pub async fn broadcast<F, T>(
+ sender_id: ConnectionId,
+ receiver_ids: Vec<ConnectionId>,
+ mut f: F,
+) -> anyhow::Result<()>
+where
+ F: FnMut(ConnectionId) -> T,
+ T: Future<Output = anyhow::Result<()>>,
+{
+ let futures = receiver_ids
+ .into_iter()
+ .filter(|id| *id != sender_id)
+ .map(|id| f(id));
+ futures::future::try_join_all(futures).await?;
+ Ok(())
+}
+
+impl ServerState {
fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
if let Some(connection) = self.connections.get_mut(&connection_id) {
connection.channels.insert(channel_id);
@@ -245,8 +688,16 @@ impl State {
}
}
+ fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
+ Ok(self
+ .connections
+ .get(&connection_id)
+ .ok_or_else(|| anyhow!("unknown connection"))?
+ .user_id)
+ }
+
// Add the given connection as a guest of the given worktree
- pub fn join_worktree(
+ fn join_worktree(
&mut self,
connection_id: ConnectionId,
worktree_id: u64,
@@ -275,14 +726,6 @@ impl State {
}
}
- fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
- Ok(self
- .connections
- .get(&connection_id)
- .ok_or_else(|| anyhow!("unknown connection"))?
- .user_id)
- }
-
fn read_worktree(
&self,
worktree_id: u64,
@@ -330,26 +773,30 @@ impl State {
}
}
-pub fn build_server(state: &Arc<AppState>, rpc: &Arc<Peer>) -> Arc<Server> {
- ServerBuilder::default()
- .on_message(share_worktree)
- .on_message(join_worktree)
- .on_message(update_worktree)
- .on_message(close_worktree)
- .on_message(open_buffer)
- .on_message(close_buffer)
- .on_message(update_buffer)
- .on_message(buffer_saved)
- .on_message(save_buffer)
- .on_message(get_channels)
- .on_message(get_users)
- .on_message(join_channel)
- .on_message(send_channel_message)
- .build(rpc, state)
+impl Worktree {
+ pub fn connection_ids(&self) -> Vec<ConnectionId> {
+ self.guest_connection_ids
+ .keys()
+ .copied()
+ .chain(self.host_connection_id)
+ .collect()
+ }
+
+ fn host_connection_id(&self) -> tide::Result<ConnectionId> {
+ Ok(self
+ .host_connection_id
+ .ok_or_else(|| anyhow!("host disconnected from worktree"))?)
+ }
+}
+
+impl Channel {
+ fn connection_ids(&self) -> Vec<ConnectionId> {
+ self.connection_ids.iter().copied().collect()
+ }
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
- let server = build_server(app.state(), rpc);
+ let server = Server::new(app.state().clone(), rpc.clone());
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
let user_id = request.ext::<UserId>().copied();
let server = server.clone();
@@ -392,453 +839,6 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
});
}
-async fn share_worktree(
- mut request: Box<TypedEnvelope<proto::ShareWorktree>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let mut state = server.state.rpc.write().await;
- let worktree_id = state.next_worktree_id;
- state.next_worktree_id += 1;
- let access_token = random_token();
- let worktree = request
- .payload
- .worktree
- .as_mut()
- .ok_or_else(|| anyhow!("missing worktree"))?;
- let entries = mem::take(&mut worktree.entries)
- .into_iter()
- .map(|entry| (entry.id, entry))
- .collect();
- state.worktrees.insert(
- worktree_id,
- Worktree {
- host_connection_id: Some(request.sender_id),
- guest_connection_ids: Default::default(),
- active_replica_ids: Default::default(),
- access_token: access_token.clone(),
- root_name: mem::take(&mut worktree.root_name),
- entries,
- },
- );
-
- server
- .rpc
- .respond(
- request.receipt(),
- proto::ShareWorktreeResponse {
- worktree_id,
- access_token,
- },
- )
- .await?;
- Ok(())
-}
-
-async fn join_worktree(
- request: Box<TypedEnvelope<proto::OpenWorktree>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let worktree_id = request.payload.worktree_id;
- let access_token = &request.payload.access_token;
-
- let mut state = server.state.rpc.write().await;
- if let Some((peer_replica_id, worktree)) =
- state.join_worktree(request.sender_id, worktree_id, access_token)
- {
- let mut peers = Vec::new();
- if let Some(host_connection_id) = worktree.host_connection_id {
- peers.push(proto::Peer {
- peer_id: host_connection_id.0,
- replica_id: 0,
- });
- }
- for (peer_conn_id, peer_replica_id) in &worktree.guest_connection_ids {
- if *peer_conn_id != request.sender_id {
- peers.push(proto::Peer {
- peer_id: peer_conn_id.0,
- replica_id: *peer_replica_id as u32,
- });
- }
- }
-
- broadcast(request.sender_id, worktree.connection_ids(), |conn_id| {
- server.rpc.send(
- conn_id,
- proto::AddPeer {
- worktree_id,
- peer: Some(proto::Peer {
- peer_id: request.sender_id.0,
- replica_id: peer_replica_id as u32,
- }),
- },
- )
- })
- .await?;
- server
- .rpc
- .respond(
- request.receipt(),
- proto::OpenWorktreeResponse {
- worktree_id,
- worktree: Some(proto::Worktree {
- root_name: worktree.root_name.clone(),
- entries: worktree.entries.values().cloned().collect(),
- }),
- replica_id: peer_replica_id as u32,
- peers,
- },
- )
- .await?;
- } else {
- server
- .rpc
- .respond(
- request.receipt(),
- proto::OpenWorktreeResponse {
- worktree_id,
- worktree: None,
- replica_id: 0,
- peers: Vec::new(),
- },
- )
- .await?;
- }
-
- Ok(())
-}
-
-async fn update_worktree(
- request: Box<TypedEnvelope<proto::UpdateWorktree>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- {
- let mut state = server.state.rpc.write().await;
- let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
- for entry_id in &request.payload.removed_entries {
- worktree.entries.remove(&entry_id);
- }
-
- for entry in &request.payload.updated_entries {
- worktree.entries.insert(entry.id, entry.clone());
- }
- }
-
- broadcast_in_worktree(request.payload.worktree_id, &request, &server).await?;
- Ok(())
-}
-
-async fn close_worktree(
- request: Box<TypedEnvelope<proto::CloseWorktree>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let connection_ids;
- {
- let mut state = server.state.rpc.write().await;
- let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
- connection_ids = worktree.connection_ids();
- if worktree.host_connection_id == Some(request.sender_id) {
- worktree.host_connection_id = None;
- } else if let Some(replica_id) = worktree.guest_connection_ids.remove(&request.sender_id) {
- worktree.active_replica_ids.remove(&replica_id);
- }
- }
-
- broadcast(request.sender_id, connection_ids, |conn_id| {
- server.rpc.send(
- conn_id,
- proto::RemovePeer {
- worktree_id: request.payload.worktree_id,
- peer_id: request.sender_id.0,
- },
- )
- })
- .await?;
-
- Ok(())
-}
-
-async fn open_buffer(
- request: Box<TypedEnvelope<proto::OpenBuffer>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let receipt = request.receipt();
- let worktree_id = request.payload.worktree_id;
- let host_connection_id = server
- .state
- .rpc
- .read()
- .await
- .read_worktree(worktree_id, request.sender_id)?
- .host_connection_id()?;
-
- let response = server
- .rpc
- .forward_request(request.sender_id, host_connection_id, request.payload)
- .await?;
- server.rpc.respond(receipt, response).await?;
- Ok(())
-}
-
-async fn close_buffer(
- request: Box<TypedEnvelope<proto::CloseBuffer>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let host_connection_id = server
- .state
- .rpc
- .read()
- .await
- .read_worktree(request.payload.worktree_id, request.sender_id)?
- .host_connection_id()?;
-
- server
- .rpc
- .forward_send(request.sender_id, host_connection_id, request.payload)
- .await?;
-
- Ok(())
-}
-
-async fn save_buffer(
- request: Box<TypedEnvelope<proto::SaveBuffer>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let host;
- let guests;
- {
- let state = server.state.rpc.read().await;
- let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
- host = worktree.host_connection_id()?;
- guests = worktree
- .guest_connection_ids
- .keys()
- .copied()
- .collect::<Vec<_>>();
- }
-
- let sender = request.sender_id;
- let receipt = request.receipt();
- let response = server
- .rpc
- .forward_request(sender, host, request.payload.clone())
- .await?;
-
- broadcast(host, guests, |conn_id| {
- let response = response.clone();
- let server = &server;
- async move {
- if conn_id == sender {
- server.rpc.respond(receipt, response).await
- } else {
- server.rpc.forward_send(host, conn_id, response).await
- }
- }
- })
- .await?;
-
- Ok(())
-}
-
-async fn update_buffer(
- request: Box<TypedEnvelope<proto::UpdateBuffer>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
-}
-
-async fn buffer_saved(
- request: Box<TypedEnvelope<proto::BufferSaved>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- broadcast_in_worktree(request.payload.worktree_id, &request, &server).await
-}
-
-async fn get_channels(
- request: Box<TypedEnvelope<proto::GetChannels>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let user_id = server
- .state
- .rpc
- .read()
- .await
- .user_id_for_connection(request.sender_id)?;
- let channels = server.state.db.get_channels_for_user(user_id).await?;
- server
- .rpc
- .respond(
- request.receipt(),
- proto::GetChannelsResponse {
- channels: channels
- .into_iter()
- .map(|chan| proto::Channel {
- id: chan.id.to_proto(),
- name: chan.name,
- })
- .collect(),
- },
- )
- .await?;
- Ok(())
-}
-
-async fn get_users(
- request: Box<TypedEnvelope<proto::GetUsers>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let user_id = server
- .state
- .rpc
- .read()
- .await
- .user_id_for_connection(request.sender_id)?;
- let receipt = request.receipt();
- let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
- let users = server
- .state
- .db
- .get_users_by_ids(user_id, user_ids)
- .await?
- .into_iter()
- .map(|user| proto::User {
- id: user.id.to_proto(),
- github_login: user.github_login,
- avatar_url: String::new(),
- })
- .collect();
- server
- .rpc
- .respond(receipt, proto::GetUsersResponse { users })
- .await?;
- Ok(())
-}
-
-async fn join_channel(
- request: Box<TypedEnvelope<proto::JoinChannel>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let user_id = server
- .state
- .rpc
- .read()
- .await
- .user_id_for_connection(request.sender_id)?;
- let channel_id = ChannelId::from_proto(request.payload.channel_id);
- if !server
- .state
- .db
- .can_user_access_channel(user_id, channel_id)
- .await?
- {
- Err(anyhow!("access denied"))?;
- }
-
- server
- .state
- .rpc
- .write()
- .await
- .join_channel(request.sender_id, channel_id);
- let messages = server
- .state
- .db
- .get_recent_channel_messages(channel_id, 50)
- .await?
- .into_iter()
- .map(|msg| proto::ChannelMessage {
- id: msg.id.to_proto(),
- body: msg.body,
- timestamp: msg.sent_at.unix_timestamp() as u64,
- sender_id: msg.sender_id.to_proto(),
- })
- .collect();
- server
- .rpc
- .respond(request.receipt(), proto::JoinChannelResponse { messages })
- .await?;
- Ok(())
-}
-
-async fn send_channel_message(
- request: Box<TypedEnvelope<proto::SendChannelMessage>>,
- server: Arc<Server>,
-) -> tide::Result<()> {
- let channel_id = ChannelId::from_proto(request.payload.channel_id);
- let user_id;
- let connection_ids;
- {
- let state = server.state.rpc.read().await;
- user_id = state.user_id_for_connection(request.sender_id)?;
- if let Some(channel) = state.channels.get(&channel_id) {
- connection_ids = channel.connection_ids();
- } else {
- return Ok(());
- }
- }
-
- let timestamp = OffsetDateTime::now_utc();
- let message_id = server
- .state
- .db
- .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
- .await?;
- let message = proto::ChannelMessageSent {
- channel_id: channel_id.to_proto(),
- message: Some(proto::ChannelMessage {
- sender_id: user_id.to_proto(),
- id: message_id.to_proto(),
- body: request.payload.body,
- timestamp: timestamp.unix_timestamp() as u64,
- }),
- };
- broadcast(request.sender_id, connection_ids, |conn_id| {
- server.rpc.send(conn_id, message.clone())
- })
- .await?;
-
- Ok(())
-}
-
-async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
- worktree_id: u64,
- request: &TypedEnvelope<T>,
- server: &Arc<Server>,
-) -> tide::Result<()> {
- let connection_ids = server
- .state
- .rpc
- .read()
- .await
- .read_worktree(worktree_id, request.sender_id)?
- .connection_ids();
-
- broadcast(request.sender_id, connection_ids, |conn_id| {
- server
- .rpc
- .forward_send(request.sender_id, conn_id, request.payload.clone())
- })
- .await?;
-
- Ok(())
-}
-
-pub async fn broadcast<F, T>(
- sender_id: ConnectionId,
- receiver_ids: Vec<ConnectionId>,
- mut f: F,
-) -> anyhow::Result<()>
-where
- F: FnMut(ConnectionId) -> T,
- T: Future<Output = anyhow::Result<()>>,
-{
- let futures = receiver_ids
- .into_iter()
- .filter(|id| *id != sender_id)
- .map(|id| f(id));
- futures::future::try_join_all(futures).await?;
- Ok(())
-}
-
fn header_contains_ignore_case<T>(
request: &tide::Request<T>,
header_name: HeaderName,