Detailed changes
@@ -88,80 +88,87 @@ impl Database {
.await
}
- pub async fn join_channel_internal(
+ pub async fn join_channel(
&self,
channel_id: ChannelId,
user_id: UserId,
connection: ConnectionId,
environment: &str,
- tx: &DatabaseTransaction,
- ) -> Result<(JoinRoom, bool)> {
- let mut joined = false;
+ ) -> Result<(JoinRoom, Option<ChannelId>)> {
+ self.transaction(move |tx| async move {
+ let mut joined_channel_id = None;
- let channel = channel::Entity::find()
- .filter(channel::Column::Id.eq(channel_id))
- .one(&*tx)
- .await?;
+ let channel = channel::Entity::find()
+ .filter(channel::Column::Id.eq(channel_id))
+ .one(&*tx)
+ .await?;
- let mut role = self
- .channel_role_for_user(channel_id, user_id, &*tx)
- .await?;
+ let mut role = self
+ .channel_role_for_user(channel_id, user_id, &*tx)
+ .await?;
+
+ if role.is_none() && channel.is_some() {
+ if let Some(invitation) = self
+ .pending_invite_for_channel(channel_id, user_id, &*tx)
+ .await?
+ {
+ // note, this may be a parent channel
+ joined_channel_id = Some(invitation.channel_id);
+ role = Some(invitation.role);
+
+ channel_member::Entity::update(channel_member::ActiveModel {
+ accepted: ActiveValue::Set(true),
+ ..invitation.into_active_model()
+ })
+ .exec(&*tx)
+ .await?;
+
+ debug_assert!(
+ self.channel_role_for_user(channel_id, user_id, &*tx)
+ .await?
+ == role
+ );
+ }
+ }
+ if role.is_none()
+ && channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public)
+ {
+ let channel_id_to_join = self
+ .most_public_ancestor_for_channel(channel_id, &*tx)
+ .await?
+ .unwrap_or(channel_id);
+ role = Some(ChannelRole::Guest);
+ joined_channel_id = Some(channel_id_to_join);
- if role.is_none() {
- if channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public) {
channel_member::Entity::insert(channel_member::ActiveModel {
id: ActiveValue::NotSet,
- channel_id: ActiveValue::Set(channel_id),
+ channel_id: ActiveValue::Set(channel_id_to_join),
user_id: ActiveValue::Set(user_id),
accepted: ActiveValue::Set(true),
role: ActiveValue::Set(ChannelRole::Guest),
})
- .on_conflict(
- OnConflict::columns([
- channel_member::Column::UserId,
- channel_member::Column::ChannelId,
- ])
- .update_columns([channel_member::Column::Accepted])
- .to_owned(),
- )
.exec(&*tx)
.await?;
debug_assert!(
self.channel_role_for_user(channel_id, user_id, &*tx)
.await?
- == Some(ChannelRole::Guest)
+ == role
);
-
- role = Some(ChannelRole::Guest);
- joined = true;
}
- }
-
- if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
- Err(anyhow!("no such channel, or not allowed"))?
- }
- let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
- let room_id = self
- .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
- .await?;
+ if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
+ Err(anyhow!("no such channel, or not allowed"))?
+ }
- self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
- .await
- .map(|jr| (jr, joined))
- }
+ let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
+ let room_id = self
+ .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
+ .await?;
- pub async fn join_channel(
- &self,
- channel_id: ChannelId,
- user_id: UserId,
- connection: ConnectionId,
- environment: &str,
- ) -> Result<(JoinRoom, bool)> {
- self.transaction(move |tx| async move {
- self.join_channel_internal(channel_id, user_id, connection, environment, &*tx)
+ self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
.await
+ .map(|jr| (jr, joined_channel_id))
})
.await
}
@@ -624,29 +631,29 @@ impl Database {
admin_id: UserId,
for_user: UserId,
role: ChannelRole,
- ) -> Result<()> {
+ ) -> Result<channel_member::Model> {
self.transaction(|tx| async move {
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
- let result = channel_member::Entity::update_many()
+ let membership = channel_member::Entity::find()
.filter(
channel_member::Column::ChannelId
.eq(channel_id)
.and(channel_member::Column::UserId.eq(for_user)),
)
- .set(channel_member::ActiveModel {
- role: ActiveValue::set(role),
- ..Default::default()
- })
- .exec(&*tx)
+ .one(&*tx)
.await?;
- if result.rows_affected == 0 {
- Err(anyhow!("no such member"))?;
- }
+ let Some(membership) = membership else {
+ Err(anyhow!("no such member"))?
+ };
- Ok(())
+ let mut update = membership.into_active_model();
+ update.role = ActiveValue::Set(role);
+ let updated = channel_member::Entity::update(update).exec(&*tx).await?;
+
+ Ok(updated)
})
.await
}
@@ -844,6 +851,52 @@ impl Database {
}
}
+ pub async fn pending_invite_for_channel(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ tx: &DatabaseTransaction,
+ ) -> Result<Option<channel_member::Model>> {
+ let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
+
+ let row = channel_member::Entity::find()
+ .filter(channel_member::Column::ChannelId.is_in(channel_ids))
+ .filter(channel_member::Column::UserId.eq(user_id))
+ .filter(channel_member::Column::Accepted.eq(false))
+ .one(&*tx)
+ .await?;
+
+ Ok(row)
+ }
+
+ pub async fn most_public_ancestor_for_channel(
+ &self,
+ channel_id: ChannelId,
+ tx: &DatabaseTransaction,
+ ) -> Result<Option<ChannelId>> {
+ let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
+
+ let rows = channel::Entity::find()
+ .filter(channel::Column::Id.is_in(channel_ids.clone()))
+ .filter(channel::Column::Visibility.eq(ChannelVisibility::Public))
+ .all(&*tx)
+ .await?;
+
+ let mut visible_channels: HashSet<ChannelId> = HashSet::default();
+
+ for row in rows {
+ visible_channels.insert(row.id);
+ }
+
+ for ancestor in channel_ids.into_iter().rev() {
+ if visible_channels.contains(&ancestor) {
+ return Ok(Some(ancestor));
+ }
+ }
+
+ Ok(None)
+ }
+
pub async fn channel_role_for_user(
&self,
channel_id: ChannelId,
@@ -864,7 +917,8 @@ impl Database {
.filter(
channel_member::Column::ChannelId
.is_in(channel_ids)
- .and(channel_member::Column::UserId.eq(user_id)),
+ .and(channel_member::Column::UserId.eq(user_id))
+ .and(channel_member::Column::Accepted.eq(true)),
)
.select_only()
.column(channel_member::Column::ChannelId)
@@ -1009,52 +1063,22 @@ impl Database {
Ok(results)
}
- /// Returns the channel with the given ID and:
- /// - true if the user is a member
- /// - false if the user hasn't accepted the invitation yet
- pub async fn get_channel(
- &self,
- channel_id: ChannelId,
- user_id: UserId,
- ) -> Result<Option<(Channel, bool)>> {
+ /// Returns the channel with the given ID
+ pub async fn get_channel(&self, channel_id: ChannelId, user_id: UserId) -> Result<Channel> {
self.transaction(|tx| async move {
- let tx = tx;
+ self.check_user_is_channel_participant(channel_id, user_id, &*tx)
+ .await?;
let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
+ let Some(channel) = channel else {
+ Err(anyhow!("no such channel"))?
+ };
- if let Some(channel) = channel {
- if self
- .check_user_is_channel_member(channel_id, user_id, &*tx)
- .await
- .is_err()
- {
- return Ok(None);
- }
-
- let channel_membership = channel_member::Entity::find()
- .filter(
- channel_member::Column::ChannelId
- .eq(channel_id)
- .and(channel_member::Column::UserId.eq(user_id)),
- )
- .one(&*tx)
- .await?;
-
- let is_accepted = channel_membership
- .map(|membership| membership.accepted)
- .unwrap_or(false);
-
- Ok(Some((
- Channel {
- id: channel.id,
- visibility: channel.visibility,
- name: channel.name,
- },
- is_accepted,
- )))
- } else {
- Ok(None)
- }
+ Ok(Channel {
+ id: channel.id,
+ visibility: channel.visibility,
+ name: channel.name,
+ })
})
.await
}
@@ -51,7 +51,7 @@ async fn test_channels(db: &Arc<Database>) {
let zed_id = db.create_root_channel("zed", a_id).await.unwrap();
// Make sure that people cannot read channels they haven't been invited to
- assert!(db.get_channel(zed_id, b_id).await.unwrap().is_none());
+ assert!(db.get_channel(zed_id, b_id).await.is_err());
db.invite_channel_member(zed_id, b_id, a_id, ChannelRole::Member)
.await
@@ -157,7 +157,7 @@ async fn test_channels(db: &Arc<Database>) {
// Remove a single channel
db.delete_channel(crdb_id, a_id).await.unwrap();
- assert!(db.get_channel(crdb_id, a_id).await.unwrap().is_none());
+ assert!(db.get_channel(crdb_id, a_id).await.is_err());
// Remove a channel tree
let (mut channel_ids, user_ids) = db.delete_channel(rust_id, a_id).await.unwrap();
@@ -165,9 +165,9 @@ async fn test_channels(db: &Arc<Database>) {
assert_eq!(channel_ids, &[rust_id, cargo_id, cargo_ra_id]);
assert_eq!(user_ids, &[a_id]);
- assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none());
- assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none());
- assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none());
+ assert!(db.get_channel(rust_id, a_id).await.is_err());
+ assert!(db.get_channel(cargo_id, a_id).await.is_err());
+ assert!(db.get_channel(cargo_ra_id, a_id).await.is_err());
}
test_both_dbs!(
@@ -381,11 +381,7 @@ async fn test_channel_renames(db: &Arc<Database>) {
let zed_archive_id = zed_id;
- let (channel, _) = db
- .get_channel(zed_archive_id, user_1)
- .await
- .unwrap()
- .unwrap();
+ let channel = db.get_channel(zed_archive_id, user_1).await.unwrap();
assert_eq!(channel.name, "zed-archive");
let non_permissioned_rename = db
@@ -860,12 +856,6 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
})
.await
.unwrap();
- db.transaction(|tx| async move {
- db.check_user_is_channel_participant(vim_channel, guest, &*tx)
- .await
- })
- .await
- .unwrap();
let members = db
.get_channel_participant_details(vim_channel, admin)
@@ -896,6 +886,13 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
.await
.unwrap();
+ db.transaction(|tx| async move {
+ db.check_user_is_channel_participant(vim_channel, guest, &*tx)
+ .await
+ })
+ .await
+ .unwrap();
+
let channels = db.get_channels_for_user(guest).await.unwrap().channels;
assert_dag(channels, &[(vim_channel, None)]);
let channels = db.get_channels_for_user(member).await.unwrap().channels;
@@ -953,29 +950,7 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
.await
.unwrap();
- db.transaction(|tx| async move {
- db.check_user_is_channel_participant(zed_channel, guest, &*tx)
- .await
- })
- .await
- .unwrap();
- assert!(db
- .transaction(|tx| async move {
- db.check_user_is_channel_participant(active_channel, guest, &*tx)
- .await
- })
- .await
- .is_err(),);
-
- db.transaction(|tx| async move {
- db.check_user_is_channel_participant(vim_channel, guest, &*tx)
- .await
- })
- .await
- .unwrap();
-
// currently people invited to parent channels are not shown here
- // (though they *do* have permissions!)
let members = db
.get_channel_participant_details(vim_channel, admin)
.await
@@ -1000,6 +975,27 @@ async fn test_user_is_channel_participant(db: &Arc<Database>) {
.await
.unwrap();
+ db.transaction(|tx| async move {
+ db.check_user_is_channel_participant(zed_channel, guest, &*tx)
+ .await
+ })
+ .await
+ .unwrap();
+ assert!(db
+ .transaction(|tx| async move {
+ db.check_user_is_channel_participant(active_channel, guest, &*tx)
+ .await
+ })
+ .await
+ .is_err(),);
+
+ db.transaction(|tx| async move {
+ db.check_user_is_channel_participant(vim_channel, guest, &*tx)
+ .await
+ })
+ .await
+ .unwrap();
+
let members = db
.get_channel_participant_details(vim_channel, admin)
.await
@@ -38,7 +38,7 @@ use lazy_static::lazy_static;
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
- self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, JoinRoom,
+ self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
},
Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
@@ -2289,10 +2289,7 @@ async fn invite_channel_member(
)
.await?;
- let (channel, _) = db
- .get_channel(channel_id, session.user_id)
- .await?
- .ok_or_else(|| anyhow!("channel not found"))?;
+ let channel = db.get_channel(channel_id, session.user_id).await?;
let mut update = proto::UpdateChannels::default();
update.channel_invitations.push(proto::Channel {
@@ -2380,21 +2377,19 @@ async fn set_channel_member_role(
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
let member_id = UserId::from_proto(request.user_id);
- db.set_channel_member_role(
- channel_id,
- session.user_id,
- member_id,
- request.role().into(),
- )
- .await?;
+ let channel_member = db
+ .set_channel_member_role(
+ channel_id,
+ session.user_id,
+ member_id,
+ request.role().into(),
+ )
+ .await?;
- let (channel, has_accepted) = db
- .get_channel(channel_id, member_id)
- .await?
- .ok_or_else(|| anyhow!("channel not found"))?;
+ let channel = db.get_channel(channel_id, session.user_id).await?;
let mut update = proto::UpdateChannels::default();
- if has_accepted {
+ if channel_member.accepted {
update.channel_permissions.push(proto::ChannelPermission {
channel_id: channel.id.to_proto(),
role: request.role,
@@ -2724,9 +2719,11 @@ async fn join_channel_internal(
channel_id: joined_room.channel_id.map(|id| id.to_proto()),
live_kit_connection_info,
})?;
+ dbg!("Joined channel", &joined_channel);
- if joined_channel {
- channel_membership_updated(db, channel_id, &session).await?
+ if let Some(joined_channel) = joined_channel {
+ dbg!("CMU");
+ channel_membership_updated(db, joined_channel, &session).await?
}
room_updated(&joined_room.room, &session.peer);
@@ -7,7 +7,7 @@ use channel::{ChannelId, ChannelMembership, ChannelStore};
use client::User;
use gpui::{executor::Deterministic, ModelHandle, TestAppContext};
use rpc::{
- proto::{self},
+ proto::{self, ChannelRole},
RECEIVE_TIMEOUT,
};
use std::sync::Arc;
@@ -965,6 +965,67 @@ async fn test_guest_access(
})
}
+#[gpui::test]
+async fn test_invite_access(
+ deterministic: Arc<Deterministic>,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ deterministic.forbid_parking();
+
+ let mut server = TestServer::start(&deterministic).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+
+ let channels = server
+ .make_channel_tree(
+ &[("channel-a", None), ("channel-b", Some("channel-a"))],
+ (&client_a, cx_a),
+ )
+ .await;
+ let channel_a_id = channels[0];
+ let channel_b_id = channels[0];
+
+ let active_call_b = cx_b.read(ActiveCall::global);
+
+ // should not be allowed to join
+ assert!(active_call_b
+ .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx))
+ .await
+ .is_err());
+
+ client_a
+ .channel_store()
+ .update(cx_a, |channel_store, cx| {
+ channel_store.invite_member(
+ channel_a_id,
+ client_b.user_id().unwrap(),
+ ChannelRole::Member,
+ cx,
+ )
+ })
+ .await
+ .unwrap();
+
+ active_call_b
+ .update(cx_b, |call, cx| call.join_channel(channel_b_id, cx))
+ .await
+ .unwrap();
+
+ deterministic.run_until_parked();
+
+ client_b.channel_store().update(cx_b, |channel_store, _| {
+ assert!(channel_store.channel_for_id(channel_b_id).is_some());
+ assert!(channel_store.channel_for_id(channel_a_id).is_some());
+ });
+
+ client_a.channel_store().update(cx_a, |channel_store, _| {
+ let participants = channel_store.channel_participants(channel_b_id);
+ assert_eq!(participants.len(), 1);
+ assert_eq!(participants[0].id, client_b.user_id().unwrap());
+ })
+}
+
#[gpui::test]
async fn test_channel_moving(
deterministic: Arc<Deterministic>,