@@ -112,12 +112,11 @@ impl Server {
addr: String,
user_id: UserId,
) -> impl Future<Output = ()> {
- let this = self.clone();
+ let mut this = self.clone();
async move {
let (connection_id, handle_io, mut incoming_rx) =
this.peer.add_connection(connection).await;
- this.store
- .write()
+ this.state_mut()
.await
.add_connection(connection_id, user_id);
if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
@@ -167,9 +166,9 @@ impl Server {
}
}
- async fn sign_out(self: &Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
+ async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
self.peer.disconnect(connection_id).await;
- let removed_connection = self.store.write().await.remove_connection(connection_id)?;
+ let removed_connection = self.state_mut().await.remove_connection(connection_id)?;
for (worktree_id, worktree) in removed_connection.hosted_worktrees {
if let Some(share) = worktree.share {
@@ -210,13 +209,12 @@ impl Server {
}
async fn open_worktree(
- self: Arc<Server>,
+ mut self: Arc<Server>,
request: TypedEnvelope<proto::OpenWorktree>,
) -> tide::Result<()> {
let receipt = request.receipt();
let host_user_id = self
- .store
- .read()
+ .state()
.await
.user_id_for_connection(request.sender_id)?;
@@ -238,7 +236,7 @@ impl Server {
}
let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
- let worktree_id = self.store.write().await.add_worktree(Worktree {
+ let worktree_id = self.state_mut().await.add_worktree(Worktree {
host_connection_id: request.sender_id,
collaborator_user_ids: collaborator_user_ids.clone(),
root_name: request.payload.root_name,
@@ -255,13 +253,12 @@ impl Server {
}
async fn close_worktree(
- self: Arc<Server>,
+ mut self: Arc<Server>,
request: TypedEnvelope<proto::CloseWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let worktree = self
- .store
- .write()
+ .state_mut()
.await
.remove_worktree(worktree_id, request.sender_id)?;
@@ -282,7 +279,7 @@ impl Server {
}
async fn share_worktree(
- self: Arc<Server>,
+ mut self: Arc<Server>,
mut request: TypedEnvelope<proto::ShareWorktree>,
) -> tide::Result<()> {
let worktree = request
@@ -296,8 +293,7 @@ impl Server {
.collect();
let collaborator_user_ids =
- self.store
- .write()
+ self.state_mut()
.await
.share_worktree(worktree.id, request.sender_id, entries);
if let Some(collaborator_user_ids) = collaborator_user_ids {
@@ -320,13 +316,12 @@ impl Server {
}
async fn unshare_worktree(
- self: Arc<Server>,
+ mut self: Arc<Server>,
request: TypedEnvelope<proto::UnshareWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let worktree = self
- .store
- .write()
+ .state_mut()
.await
.unshare_worktree(worktree_id, request.sender_id)?;
@@ -342,20 +337,16 @@ impl Server {
}
async fn join_worktree(
- self: Arc<Server>,
+ mut self: Arc<Server>,
request: TypedEnvelope<proto::JoinWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let user_id = self
- .store
- .read()
+ .state()
.await
.user_id_for_connection(request.sender_id)?;
- let response;
- let connection_ids;
- let collaborator_user_ids;
- let mut state = self.store.write().await;
+ let mut state = self.state_mut().await;
match state.join_worktree(request.sender_id, user_id, worktree_id) {
Ok(JoinedWorktree {
replica_id,
@@ -376,7 +367,7 @@ impl Server {
});
}
}
- response = proto::JoinWorktreeResponse {
+ let response = proto::JoinWorktreeResponse {
worktree: Some(proto::Worktree {
id: worktree_id,
root_name: worktree.root_name.clone(),
@@ -385,10 +376,29 @@ impl Server {
replica_id: replica_id as u32,
peers,
};
- connection_ids = worktree.connection_ids();
- collaborator_user_ids = worktree.collaborator_user_ids.clone();
+ let connection_ids = worktree.connection_ids();
+ let collaborator_user_ids = worktree.collaborator_user_ids.clone();
+ drop(state);
+
+ broadcast(request.sender_id, connection_ids, |conn_id| {
+ self.peer.send(
+ conn_id,
+ proto::AddPeer {
+ worktree_id,
+ peer: Some(proto::Peer {
+ peer_id: request.sender_id.0,
+ replica_id: response.replica_id,
+ }),
+ },
+ )
+ })
+ .await?;
+ self.peer.respond(request.receipt(), response).await?;
+ self.update_collaborators_for_users(&collaborator_user_ids)
+ .await?;
}
Err(error) => {
+ drop(state);
self.peer
.respond_with_error(
request.receipt(),
@@ -397,44 +407,23 @@ impl Server {
},
)
.await?;
- return Ok(());
}
}
- drop(state);
- broadcast(request.sender_id, connection_ids, |conn_id| {
- self.peer.send(
- conn_id,
- proto::AddPeer {
- worktree_id,
- peer: Some(proto::Peer {
- peer_id: request.sender_id.0,
- replica_id: response.replica_id,
- }),
- },
- )
- })
- .await?;
- self.peer.respond(request.receipt(), response).await?;
- self.update_collaborators_for_users(&collaborator_user_ids)
- .await?;
-
Ok(())
}
async fn leave_worktree(
- self: Arc<Server>,
+ mut self: Arc<Server>,
request: TypedEnvelope<proto::LeaveWorktree>,
) -> tide::Result<()> {
let sender_id = request.sender_id;
let worktree_id = request.payload.worktree_id;
-
- if let Some(worktree) = self
- .store
- .write()
+ let worktree = self
+ .state_mut()
.await
- .leave_worktree(sender_id, worktree_id)
- {
+ .leave_worktree(sender_id, worktree_id);
+ if let Some(worktree) = worktree {
broadcast(sender_id, worktree.connection_ids, |conn_id| {
self.peer.send(
conn_id,
@@ -452,10 +441,10 @@ impl Server {
}
async fn update_worktree(
- self: Arc<Server>,
+ mut self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
) -> tide::Result<()> {
- let connection_ids = self.store.write().await.update_worktree(
+ let connection_ids = self.state_mut().await.update_worktree(
request.sender_id,
request.payload.worktree_id,
&request.payload.removed_entries,
@@ -477,8 +466,7 @@ impl Server {
) -> tide::Result<()> {
let receipt = request.receipt();
let host_connection_id = self
- .store
- .read()
+ .state()
.await
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
let response = self
@@ -494,8 +482,7 @@ impl Server {
request: TypedEnvelope<proto::CloseBuffer>,
) -> tide::Result<()> {
let host_connection_id = self
- .store
- .read()
+ .state()
.await
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
self.peer
@@ -511,7 +498,7 @@ impl Server {
let host;
let guests;
{
- let state = self.store.read().await;
+ let state = self.state().await;
host = state
.worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
guests = state
@@ -547,8 +534,7 @@ impl Server {
) -> tide::Result<()> {
broadcast(
request.sender_id,
- self.store
- .read()
+ self.state()
.await
.worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
|connection_id| {
@@ -585,8 +571,7 @@ impl Server {
request: TypedEnvelope<proto::GetChannels>,
) -> tide::Result<()> {
let user_id = self
- .store
- .read()
+ .state()
.await
.user_id_for_connection(request.sender_id)?;
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
@@ -637,7 +622,7 @@ impl Server {
) -> tide::Result<()> {
let mut send_futures = Vec::new();
- let state = self.store.read().await;
+ let state = self.state().await;
for user_id in user_ids {
let collaborators = state.collaborators_for_user(*user_id);
for connection_id in state.connection_ids_for_user(*user_id) {
@@ -657,12 +642,11 @@ impl Server {
}
async fn join_channel(
- self: Arc<Self>,
+ mut self: Arc<Self>,
request: TypedEnvelope<proto::JoinChannel>,
) -> tide::Result<()> {
let user_id = self
- .store
- .read()
+ .state()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -675,8 +659,7 @@ impl Server {
Err(anyhow!("access denied"))?;
}
- self.store
- .write()
+ self.state_mut()
.await
.join_channel(request.sender_id, channel_id);
let messages = self
@@ -706,12 +689,11 @@ impl Server {
}
async fn leave_channel(
- self: Arc<Self>,
+ mut self: Arc<Self>,
request: TypedEnvelope<proto::LeaveChannel>,
) -> tide::Result<()> {
let user_id = self
- .store
- .read()
+ .state()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -724,8 +706,7 @@ impl Server {
Err(anyhow!("access denied"))?;
}
- self.store
- .write()
+ self.state_mut()
.await
.leave_channel(request.sender_id, channel_id);
@@ -741,7 +722,7 @@ impl Server {
let user_id;
let connection_ids;
{
- let state = self.store.read().await;
+ let state = self.state().await;
user_id = state.user_id_for_connection(request.sender_id)?;
if let Some(ids) = state.channel_connection_ids(channel_id) {
connection_ids = ids;
@@ -829,8 +810,7 @@ impl Server {
request: TypedEnvelope<proto::GetChannelMessages>,
) -> tide::Result<()> {
let user_id = self
- .store
- .read()
+ .state()
.await
.user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
@@ -872,6 +852,18 @@ impl Server {
.await?;
Ok(())
}
+
+ fn state<'a>(
+ self: &'a Arc<Self>,
+ ) -> impl Future<Output = async_std::sync::RwLockReadGuard<'a, Store>> {
+ self.store.read()
+ }
+
+ fn state_mut<'a>(
+ self: &'a mut Arc<Self>,
+ ) -> impl Future<Output = async_std::sync::RwLockWriteGuard<'a, Store>> {
+ self.store.write()
+ }
}
pub async fn broadcast<F, T>(