Detailed changes
@@ -169,6 +169,30 @@ impl Database {
self.run(body).await
}
+ pub async fn weak_transaction<F, Fut, T>(&self, f: F) -> Result<T>
+ where
+ F: Send + Fn(TransactionHandle) -> Fut,
+ Fut: Send + Future<Output = Result<T>>,
+ {
+ let body = async {
+ let (tx, result) = self.with_weak_transaction(&f).await?;
+ match result {
+ Ok(result) => match tx.commit().await.map_err(Into::into) {
+ Ok(()) => return Ok(result),
+ Err(error) => {
+ return Err(error);
+ }
+ },
+ Err(error) => {
+ tx.rollback().await?;
+ return Err(error);
+ }
+ }
+ };
+
+ self.run(body).await
+ }
+
/// The same as room_transaction, but if you need to only optionally return a Room.
async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
where
@@ -284,6 +308,30 @@ impl Database {
Ok((tx, result))
}
+ async fn with_weak_transaction<F, Fut, T>(
+ &self,
+ f: &F,
+ ) -> Result<(DatabaseTransaction, Result<T>)>
+ where
+ F: Send + Fn(TransactionHandle) -> Fut,
+ Fut: Send + Future<Output = Result<T>>,
+ {
+ let tx = self
+ .pool
+ .begin_with_config(Some(IsolationLevel::ReadCommitted), None)
+ .await?;
+
+ let mut tx = Arc::new(Some(tx));
+ let result = f(TransactionHandle(tx.clone())).await;
+ let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
+ return Err(anyhow!(
+ "couldn't complete transaction because it's still in use"
+ ))?;
+ };
+
+ Ok((tx, result))
+ }
+
async fn run<F, T>(&self, future: F) -> Result<T>
where
F: Future<Output = Result<T>>,
@@ -457,9 +505,8 @@ pub struct NewUserResult {
/// The result of moving a channel.
#[derive(Debug)]
pub struct MoveChannelResult {
- pub participants_to_update: HashMap<UserId, ChannelsForUser>,
- pub participants_to_remove: HashSet<UserId>,
- pub moved_channels: HashSet<ChannelId>,
+ pub previous_participants: Vec<ChannelMember>,
+ pub descendent_ids: Vec<ChannelId>,
}
/// The result of renaming a channel.
@@ -22,7 +22,6 @@ impl Database {
Ok(self
.create_channel(name, None, creator_id)
.await?
- .channel
.id)
}
@@ -36,7 +35,6 @@ impl Database {
Ok(self
.create_channel(name, Some(parent), creator_id)
.await?
- .channel
.id)
}
@@ -46,7 +44,7 @@ impl Database {
name: &str,
parent_channel_id: Option<ChannelId>,
admin_id: UserId,
- ) -> Result<CreateChannelResult> {
+ ) -> Result<Channel> {
let name = Self::sanitize_channel_name(name)?;
self.transaction(move |tx| async move {
let mut parent = None;
@@ -72,14 +70,7 @@ impl Database {
.insert(&*tx)
.await?;
- let participants_to_update;
- if let Some(parent) = &parent {
- participants_to_update = self
- .participants_to_notify_for_channel_change(parent, &*tx)
- .await?;
- } else {
- participants_to_update = vec![];
-
+ if parent.is_none() {
channel_member::ActiveModel {
id: ActiveValue::NotSet,
channel_id: ActiveValue::Set(channel.id),
@@ -89,12 +80,9 @@ impl Database {
}
.insert(&*tx)
.await?;
- };
+ }
- Ok(CreateChannelResult {
- channel: Channel::from_model(channel, ChannelRole::Admin),
- participants_to_update,
- })
+ Ok(Channel::from_model(channel, ChannelRole::Admin))
})
.await
}
@@ -718,6 +706,19 @@ impl Database {
})
}
+ pub async fn new_participants_to_notify(
+ &self,
+ parent_channel_id: ChannelId,
+ ) -> Result<Vec<(UserId, ChannelsForUser)>> {
+ self.weak_transaction(|tx| async move {
+ let parent_channel = self.get_channel_internal(parent_channel_id, &*tx).await?;
+ self.participants_to_notify_for_channel_change(&parent_channel, &*tx)
+ .await
+ })
+ .await
+ }
+
+ // TODO: this is very expensive, and we should rethink
async fn participants_to_notify_for_channel_change(
&self,
new_parent: &channel::Model,
@@ -1287,7 +1288,7 @@ impl Database {
let mut model = channel.into_active_model();
model.parent_path = ActiveValue::Set(new_parent_path);
- let channel = model.update(&*tx).await?;
+ model.update(&*tx).await?;
if new_parent_channel.is_none() {
channel_member::ActiveModel {
@@ -1314,34 +1315,9 @@ impl Database {
.all(&*tx)
.await?;
- let participants_to_update: HashMap<_, _> = self
- .participants_to_notify_for_channel_change(
- new_parent_channel.as_ref().unwrap_or(&channel),
- &*tx,
- )
- .await?
- .into_iter()
- .collect();
-
- let mut moved_channels: HashSet<ChannelId> = HashSet::default();
- for id in descendent_ids {
- moved_channels.insert(id);
- }
- moved_channels.insert(channel_id);
-
- let mut participants_to_remove: HashSet<UserId> = HashSet::default();
- for participant in previous_participants {
- if participant.kind == proto::channel_member::Kind::AncestorMember {
- if !participants_to_update.contains_key(&participant.user_id) {
- participants_to_remove.insert(participant.user_id);
- }
- }
- }
-
Ok(Some(MoveChannelResult {
- participants_to_remove,
- participants_to_update,
- moved_channels,
+ previous_participants,
+ descendent_ids,
}))
})
.await
@@ -15,11 +15,11 @@ test_both_dbs!(
async fn test_channel_message_retrieval(db: &Arc<Database>) {
let user = new_test_user(db, "user@example.com").await;
- let result = db.create_channel("channel", None, user).await.unwrap();
+ let channel = db.create_channel("channel", None, user).await.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
db.join_channel_chat(
- result.channel.id,
+ channel.id,
rpc::ConnectionId { owner_id, id: 0 },
user,
)
@@ -30,7 +30,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
for i in 0..10 {
all_messages.push(
db.create_channel_message(
- result.channel.id,
+ channel.id,
user,
&i.to_string(),
&[],
@@ -45,7 +45,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
}
let messages = db
- .get_channel_messages(result.channel.id, user, 3, None)
+ .get_channel_messages(channel.id, user, 3, None)
.await
.unwrap()
.into_iter()
@@ -55,7 +55,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
let messages = db
.get_channel_messages(
- result.channel.id,
+ channel.id,
user,
4,
Some(MessageId::from_proto(all_messages[6])),
@@ -370,7 +370,6 @@ async fn test_channel_message_mentions(db: &Arc<Database>) {
.create_channel("channel", None, user_a)
.await
.unwrap()
- .channel
.id;
db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
.await
@@ -3,9 +3,9 @@ mod connection_pool;
use crate::{
auth::{self, Impersonator},
db::{
- self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
+ self, BufferId, ChannelId, ChannelRole, ChannelsForUser,
CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
- MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
+ NotificationId, ProjectId, RemoveChannelMemberResult,
RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
User, UserId,
},
@@ -2301,10 +2301,7 @@ async fn create_channel(
let db = session.db().await;
let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
- let CreateChannelResult {
- channel,
- participants_to_update,
- } = db
+ let channel = db
.create_channel(&request.name, parent_id, session.user_id)
.await?;
@@ -2313,6 +2310,13 @@ async fn create_channel(
parent_id: request.parent_id,
})?;
+ let participants_to_update;
+ if let Some(parent) = parent_id {
+ participants_to_update = db.new_participants_to_notify(parent).await?;
+ } else {
+ participants_to_update = vec![];
+ }
+
let connection_pool = session.connection_pool().await;
for (user_id, channels) in participants_to_update {
let update = build_channels_update(channels, vec![]);
@@ -2566,47 +2570,55 @@ async fn move_channel(
let channel_id = ChannelId::from_proto(request.channel_id);
let to = request.to.map(ChannelId::from_proto);
- let result = session
- .db()
- .await
- .move_channel(channel_id, to, session.user_id)
- .await?;
+ let result = session.db().await.move_channel(channel_id, to, session.user_id).await?;
- notify_channel_moved(result, session).await?;
+ if let Some(result) = result {
+ let participants_to_update: HashMap<_, _> = session.db().await
+ .new_participants_to_notify(
+ to.unwrap_or(channel_id)
+ )
+ .await?
+ .into_iter()
+ .collect();
- response.send(Ack {})?;
- Ok(())
-}
+ let mut moved_channels: HashSet<ChannelId> = HashSet::default();
+ for id in result.descendent_ids {
+ moved_channels.insert(id);
+ }
+ moved_channels.insert(channel_id);
-async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
- let Some(MoveChannelResult {
- participants_to_remove,
- participants_to_update,
- moved_channels,
- }) = result
- else {
- return Ok(());
- };
- let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
+ let mut participants_to_remove: HashSet<UserId> = HashSet::default();
+ for participant in result.previous_participants {
+ if participant.kind == proto::channel_member::Kind::AncestorMember {
+ if !participants_to_update.contains_key(&participant.user_id) {
+ participants_to_remove.insert(participant.user_id);
+ }
+ }
+ }
- let connection_pool = session.connection_pool().await;
- for (user_id, channels) in participants_to_update {
- let mut update = build_channels_update(channels, vec![]);
- update.delete_channels = moved_channels.clone();
- for connection_id in connection_pool.user_connection_ids(user_id) {
- session.peer.send(connection_id, update.clone())?;
- }
- }
+ let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
- for user_id in participants_to_remove {
- let update = proto::UpdateChannels {
- delete_channels: moved_channels.clone(),
- ..Default::default()
- };
- for connection_id in connection_pool.user_connection_ids(user_id) {
- session.peer.send(connection_id, update.clone())?;
+ let connection_pool = session.connection_pool().await;
+ for (user_id, channels) in participants_to_update {
+ let mut update = build_channels_update(channels, vec![]);
+ update.delete_channels = moved_channels.clone();
+ for connection_id in connection_pool.user_connection_ids(user_id) {
+ session.peer.send(connection_id, update.clone())?;
+ }
+ }
+
+ for user_id in participants_to_remove {
+ let update = proto::UpdateChannels {
+ delete_channels: moved_channels.clone(),
+ ..Default::default()
+ };
+ for connection_id in connection_pool.user_connection_ids(user_id) {
+ session.peer.send(connection_id, update.clone())?;
+ }
+ }
}
- }
+
+ response.send(Ack {})?;
Ok(())
}