Detailed changes
@@ -972,6 +972,7 @@ impl ChannelStore {
let mut all_user_ids = Vec::new();
let channel_participants = payload.channel_participants;
+ dbg!(&channel_participants);
for entry in &channel_participants {
for user_id in entry.participant_user_ids.iter() {
if let Err(ix) = all_user_ids.binary_search(user_id) {
@@ -88,6 +88,84 @@ impl Database {
.await
}
+ pub async fn join_channel_internal(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ connection: ConnectionId,
+ environment: &str,
+ tx: &DatabaseTransaction,
+ ) -> Result<(JoinRoom, bool)> {
+ let mut joined = false;
+
+ let channel = channel::Entity::find()
+ .filter(channel::Column::Id.eq(channel_id))
+ .one(&*tx)
+ .await?;
+
+ let mut role = self
+ .channel_role_for_user(channel_id, user_id, &*tx)
+ .await?;
+
+ if role.is_none() {
+ if channel.as_ref().map(|c| c.visibility) == Some(ChannelVisibility::Public) {
+ channel_member::Entity::insert(channel_member::ActiveModel {
+ id: ActiveValue::NotSet,
+ channel_id: ActiveValue::Set(channel_id),
+ user_id: ActiveValue::Set(user_id),
+ accepted: ActiveValue::Set(true),
+ role: ActiveValue::Set(ChannelRole::Guest),
+ })
+ .on_conflict(
+ OnConflict::columns([
+ channel_member::Column::UserId,
+ channel_member::Column::ChannelId,
+ ])
+ .update_columns([channel_member::Column::Accepted])
+ .to_owned(),
+ )
+ .exec(&*tx)
+ .await?;
+
+ debug_assert!(
+ self.channel_role_for_user(channel_id, user_id, &*tx)
+ .await?
+ == Some(ChannelRole::Guest)
+ );
+
+ role = Some(ChannelRole::Guest);
+ joined = true;
+ }
+ }
+
+ if channel.is_none() || role.is_none() || role == Some(ChannelRole::Banned) {
+ Err(anyhow!("no such channel, or not allowed"))?
+ }
+
+ let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
+ let room_id = self
+ .get_or_create_channel_room(channel_id, &live_kit_room, environment, &*tx)
+ .await?;
+
+ self.join_channel_room_internal(channel_id, room_id, user_id, connection, &*tx)
+ .await
+ .map(|jr| (jr, joined))
+ }
+
+ pub async fn join_channel(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ connection: ConnectionId,
+ environment: &str,
+ ) -> Result<(JoinRoom, bool)> {
+ self.transaction(move |tx| async move {
+ self.join_channel_internal(channel_id, user_id, connection, environment, &*tx)
+ .await
+ })
+ .await
+ }
+
pub async fn set_channel_visibility(
&self,
channel_id: ChannelId,
@@ -981,38 +1059,39 @@ impl Database {
.await
}
- pub async fn get_or_create_channel_room(
+ pub(crate) async fn get_or_create_channel_room(
&self,
channel_id: ChannelId,
live_kit_room: &str,
- enviroment: &str,
+ environment: &str,
+ tx: &DatabaseTransaction,
) -> Result<RoomId> {
- self.transaction(|tx| async move {
- let tx = tx;
-
- let room = room::Entity::find()
- .filter(room::Column::ChannelId.eq(channel_id))
- .one(&*tx)
- .await?;
+ let room = room::Entity::find()
+ .filter(room::Column::ChannelId.eq(channel_id))
+ .one(&*tx)
+ .await?;
- let room_id = if let Some(room) = room {
- room.id
- } else {
- let result = room::Entity::insert(room::ActiveModel {
- channel_id: ActiveValue::Set(Some(channel_id)),
- live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
- enviroment: ActiveValue::Set(Some(enviroment.to_string())),
- ..Default::default()
- })
- .exec(&*tx)
- .await?;
+ let room_id = if let Some(room) = room {
+ if let Some(env) = room.enviroment {
+ if &env != environment {
+ Err(anyhow!("must join using the {} release", env))?;
+ }
+ }
+ room.id
+ } else {
+ let result = room::Entity::insert(room::ActiveModel {
+ channel_id: ActiveValue::Set(Some(channel_id)),
+ live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
+ enviroment: ActiveValue::Set(Some(environment.to_string())),
+ ..Default::default()
+ })
+ .exec(&*tx)
+ .await?;
- result.last_insert_id
- };
+ result.last_insert_id
+ };
- Ok(room_id)
- })
- .await
+ Ok(room_id)
}
// Insert an edge from the given channel to the given other channel.
@@ -300,99 +300,139 @@ impl Database {
}
}
- #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
- enum QueryParticipantIndices {
- ParticipantIndex,
- }
- let existing_participant_indices: Vec<i32> = room_participant::Entity::find()
- .filter(
- room_participant::Column::RoomId
- .eq(room_id)
- .and(room_participant::Column::ParticipantIndex.is_not_null()),
- )
- .select_only()
- .column(room_participant::Column::ParticipantIndex)
- .into_values::<_, QueryParticipantIndices>()
- .all(&*tx)
- .await?;
-
- let mut participant_index = 0;
- while existing_participant_indices.contains(&participant_index) {
- participant_index += 1;
+ if channel_id.is_some() {
+ Err(anyhow!("tried to join channel call directly"))?
}
- if let Some(channel_id) = channel_id {
- self.check_user_is_channel_member(channel_id, user_id, &*tx)
- .await?;
+ let participant_index = self
+ .get_next_participant_index_internal(room_id, &*tx)
+ .await?;
- room_participant::Entity::insert_many([room_participant::ActiveModel {
- room_id: ActiveValue::set(room_id),
- user_id: ActiveValue::set(user_id),
+ let result = room_participant::Entity::update_many()
+ .filter(
+ Condition::all()
+ .add(room_participant::Column::RoomId.eq(room_id))
+ .add(room_participant::Column::UserId.eq(user_id))
+ .add(room_participant::Column::AnsweringConnectionId.is_null()),
+ )
+ .set(room_participant::ActiveModel {
+ participant_index: ActiveValue::Set(Some(participant_index)),
answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
answering_connection_server_id: ActiveValue::set(Some(ServerId(
connection.owner_id as i32,
))),
answering_connection_lost: ActiveValue::set(false),
- calling_user_id: ActiveValue::set(user_id),
- calling_connection_id: ActiveValue::set(connection.id as i32),
- calling_connection_server_id: ActiveValue::set(Some(ServerId(
- connection.owner_id as i32,
- ))),
- participant_index: ActiveValue::Set(Some(participant_index)),
..Default::default()
- }])
- .on_conflict(
- OnConflict::columns([room_participant::Column::UserId])
- .update_columns([
- room_participant::Column::AnsweringConnectionId,
- room_participant::Column::AnsweringConnectionServerId,
- room_participant::Column::AnsweringConnectionLost,
- room_participant::Column::ParticipantIndex,
- ])
- .to_owned(),
- )
+ })
.exec(&*tx)
.await?;
- } else {
- let result = room_participant::Entity::update_many()
- .filter(
- Condition::all()
- .add(room_participant::Column::RoomId.eq(room_id))
- .add(room_participant::Column::UserId.eq(user_id))
- .add(room_participant::Column::AnsweringConnectionId.is_null()),
- )
- .set(room_participant::ActiveModel {
- participant_index: ActiveValue::Set(Some(participant_index)),
- answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
- answering_connection_server_id: ActiveValue::set(Some(ServerId(
- connection.owner_id as i32,
- ))),
- answering_connection_lost: ActiveValue::set(false),
- ..Default::default()
- })
- .exec(&*tx)
- .await?;
- if result.rows_affected == 0 {
- Err(anyhow!("room does not exist or was already joined"))?;
- }
+ if result.rows_affected == 0 {
+ Err(anyhow!("room does not exist or was already joined"))?;
}
let room = self.get_room(room_id, &tx).await?;
- let channel_members = if let Some(channel_id) = channel_id {
- self.get_channel_participants_internal(channel_id, &tx)
- .await?
- } else {
- Vec::new()
- };
Ok(JoinRoom {
room,
- channel_id,
- channel_members,
+ channel_id: None,
+ channel_members: vec![],
})
})
.await
}
+ async fn get_next_participant_index_internal(
+ &self,
+ room_id: RoomId,
+ tx: &DatabaseTransaction,
+ ) -> Result<i32> {
+ #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+ enum QueryParticipantIndices {
+ ParticipantIndex,
+ }
+ let existing_participant_indices: Vec<i32> = room_participant::Entity::find()
+ .filter(
+ room_participant::Column::RoomId
+ .eq(room_id)
+ .and(room_participant::Column::ParticipantIndex.is_not_null()),
+ )
+ .select_only()
+ .column(room_participant::Column::ParticipantIndex)
+ .into_values::<_, QueryParticipantIndices>()
+ .all(&*tx)
+ .await?;
+
+ let mut participant_index = 0;
+ while existing_participant_indices.contains(&participant_index) {
+ participant_index += 1;
+ }
+
+ Ok(participant_index)
+ }
+
+ pub async fn channel_id_for_room(&self, room_id: RoomId) -> Result<Option<ChannelId>> {
+ self.transaction(|tx| async move {
+ let room: Option<room::Model> = room::Entity::find()
+ .filter(room::Column::Id.eq(room_id))
+ .one(&*tx)
+ .await?;
+
+ Ok(room.and_then(|room| room.channel_id))
+ })
+ .await
+ }
+
+ pub(crate) async fn join_channel_room_internal(
+ &self,
+ channel_id: ChannelId,
+ room_id: RoomId,
+ user_id: UserId,
+ connection: ConnectionId,
+ tx: &DatabaseTransaction,
+ ) -> Result<JoinRoom> {
+ let participant_index = self
+ .get_next_participant_index_internal(room_id, &*tx)
+ .await?;
+
+ room_participant::Entity::insert_many([room_participant::ActiveModel {
+ room_id: ActiveValue::set(room_id),
+ user_id: ActiveValue::set(user_id),
+ answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
+ answering_connection_server_id: ActiveValue::set(Some(ServerId(
+ connection.owner_id as i32,
+ ))),
+ answering_connection_lost: ActiveValue::set(false),
+ calling_user_id: ActiveValue::set(user_id),
+ calling_connection_id: ActiveValue::set(connection.id as i32),
+ calling_connection_server_id: ActiveValue::set(Some(ServerId(
+ connection.owner_id as i32,
+ ))),
+ participant_index: ActiveValue::Set(Some(participant_index)),
+ ..Default::default()
+ }])
+ .on_conflict(
+ OnConflict::columns([room_participant::Column::UserId])
+ .update_columns([
+ room_participant::Column::AnsweringConnectionId,
+ room_participant::Column::AnsweringConnectionServerId,
+ room_participant::Column::AnsweringConnectionLost,
+ room_participant::Column::ParticipantIndex,
+ ])
+ .to_owned(),
+ )
+ .exec(&*tx)
+ .await?;
+
+ let room = self.get_room(room_id, &tx).await?;
+ let channel_members = self
+ .get_channel_participants_internal(channel_id, &tx)
+ .await?;
+ Ok(JoinRoom {
+ room,
+ channel_id: Some(channel_id),
+ channel_members,
+ })
+ }
+
pub async fn rejoin_room(
&self,
rejoin_room: proto::RejoinRoom,
@@ -8,7 +8,7 @@ use crate::{
db::{
queries::channels::ChannelGraph,
tests::{graph, TEST_RELEASE_CHANNEL},
- ChannelId, ChannelRole, Database, NewUserParams, UserId,
+ ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId,
},
test_both_dbs,
};
@@ -207,15 +207,11 @@ async fn test_joining_channels(db: &Arc<Database>) {
.user_id;
let channel_1 = db.create_root_channel("channel_1", user_1).await.unwrap();
- let room_1 = db
- .get_or_create_channel_room(channel_1, "1", TEST_RELEASE_CHANNEL)
- .await
- .unwrap();
// can join a room with membership to its channel
- let joined_room = db
- .join_room(
- room_1,
+ let (joined_room, _) = db
+ .join_channel(
+ channel_1,
user_1,
ConnectionId { owner_id, id: 1 },
TEST_RELEASE_CHANNEL,
@@ -224,11 +220,12 @@ async fn test_joining_channels(db: &Arc<Database>) {
.unwrap();
assert_eq!(joined_room.room.participants.len(), 1);
+ let room_id = RoomId::from_proto(joined_room.room.id);
drop(joined_room);
// cannot join a room without membership to its channel
assert!(db
.join_room(
- room_1,
+ room_id,
user_2,
ConnectionId { owner_id, id: 1 },
TEST_RELEASE_CHANNEL
@@ -38,7 +38,7 @@ use lazy_static::lazy_static;
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
- self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage,
+ self, Ack, AnyTypedEnvelope, ChannelEdge, EntityMessage, EnvelopedMessage, JoinRoom,
LiveKitConnectionInfo, RequestMessage, UpdateChannelBufferCollaborators,
},
Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
@@ -977,6 +977,13 @@ async fn join_room(
session: Session,
) -> Result<()> {
let room_id = RoomId::from_proto(request.id);
+
+ let channel_id = session.db().await.channel_id_for_room(room_id).await?;
+
+ if let Some(channel_id) = channel_id {
+ return join_channel_internal(channel_id, Box::new(response), session).await;
+ }
+
let joined_room = {
let room = session
.db()
@@ -992,16 +999,6 @@ async fn join_room(
room.into_inner()
};
- if let Some(channel_id) = joined_room.channel_id {
- channel_updated(
- channel_id,
- &joined_room.room,
- &joined_room.channel_members,
- &session.peer,
- &*session.connection_pool().await,
- )
- }
-
for connection_id in session
.connection_pool()
.await
@@ -1039,7 +1036,7 @@ async fn join_room(
response.send(proto::JoinRoomResponse {
room: Some(joined_room.room),
- channel_id: joined_room.channel_id.map(|id| id.to_proto()),
+ channel_id: None,
live_kit_connection_info,
})?;
@@ -2602,54 +2599,68 @@ async fn respond_to_channel_invite(
db.respond_to_channel_invite(channel_id, session.user_id, request.accept)
.await?;
+ if request.accept {
+ channel_membership_updated(db, channel_id, &session).await?;
+ } else {
+ let mut update = proto::UpdateChannels::default();
+ update
+ .remove_channel_invitations
+ .push(channel_id.to_proto());
+ session.peer.send(session.connection_id, update)?;
+ }
+ response.send(proto::Ack {})?;
+
+ Ok(())
+}
+
+async fn channel_membership_updated(
+ db: tokio::sync::MutexGuard<'_, DbHandle>,
+ channel_id: ChannelId,
+ session: &Session,
+) -> Result<(), crate::Error> {
let mut update = proto::UpdateChannels::default();
update
.remove_channel_invitations
.push(channel_id.to_proto());
- if request.accept {
- let result = db.get_channel_for_user(channel_id, session.user_id).await?;
- update
+
+ let result = db.get_channel_for_user(channel_id, session.user_id).await?;
+ update.channels.extend(
+ result
.channels
- .extend(
- result
- .channels
- .channels
- .into_iter()
- .map(|channel| proto::Channel {
- id: channel.id.to_proto(),
- visibility: channel.visibility.into(),
- name: channel.name,
- }),
- );
- update.unseen_channel_messages = result.channel_messages;
- update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
- update.insert_edge = result.channels.edges;
- update
- .channel_participants
- .extend(
- result
- .channel_participants
- .into_iter()
- .map(|(channel_id, user_ids)| proto::ChannelParticipants {
- channel_id: channel_id.to_proto(),
- participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
- }),
- );
- update
- .channel_permissions
- .extend(
- result
- .channels_with_admin_privileges
- .into_iter()
- .map(|channel_id| proto::ChannelPermission {
- channel_id: channel_id.to_proto(),
- role: proto::ChannelRole::Admin.into(),
- }),
- );
- }
+ .channels
+ .into_iter()
+ .map(|channel| proto::Channel {
+ id: channel.id.to_proto(),
+ visibility: channel.visibility.into(),
+ name: channel.name,
+ }),
+ );
+ update.unseen_channel_messages = result.channel_messages;
+ update.unseen_channel_buffer_changes = result.unseen_buffer_changes;
+ update.insert_edge = result.channels.edges;
+ update
+ .channel_participants
+ .extend(
+ result
+ .channel_participants
+ .into_iter()
+ .map(|(channel_id, user_ids)| proto::ChannelParticipants {
+ channel_id: channel_id.to_proto(),
+ participant_user_ids: user_ids.into_iter().map(UserId::to_proto).collect(),
+ }),
+ );
+ update
+ .channel_permissions
+ .extend(
+ result
+ .channels_with_admin_privileges
+ .into_iter()
+ .map(|channel_id| proto::ChannelPermission {
+ channel_id: channel_id.to_proto(),
+ role: proto::ChannelRole::Admin.into(),
+ }),
+ );
session.peer.send(session.connection_id, update)?;
- response.send(proto::Ack {})?;
-
Ok(())
}
@@ -2659,19 +2670,35 @@ async fn join_channel(
session: Session,
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
- let live_kit_room = format!("channel-{}", nanoid::nanoid!(30));
+ join_channel_internal(channel_id, Box::new(response), session).await
+}
+
+trait JoinChannelInternalResponse {
+ fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
+}
+impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
+ fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
+ Response::<proto::JoinChannel>::send(self, result)
+ }
+}
+impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
+ fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
+ Response::<proto::JoinRoom>::send(self, result)
+ }
+}
+async fn join_channel_internal(
+ channel_id: ChannelId,
+ response: Box<impl JoinChannelInternalResponse>,
+ session: Session,
+) -> Result<()> {
let joined_room = {
leave_room_for_session(&session).await?;
let db = session.db().await;
- let room_id = db
- .get_or_create_channel_room(channel_id, &live_kit_room, &*RELEASE_CHANNEL_NAME)
- .await?;
-
- let joined_room = db
- .join_room(
- room_id,
+ let (joined_room, joined_channel) = db
+ .join_channel(
+ channel_id,
session.user_id,
session.connection_id,
RELEASE_CHANNEL_NAME.as_str(),
@@ -2698,9 +2725,13 @@ async fn join_channel(
live_kit_connection_info,
})?;
+ if joined_channel {
+ channel_membership_updated(db, channel_id, &session).await?
+ }
+
room_updated(&joined_room.room, &session.peer);
- joined_room.into_inner()
+ joined_room
};
channel_updated(
@@ -2712,7 +2743,6 @@ async fn join_channel(
);
update_user_contacts(session.user_id, &session).await?;
-
Ok(())
}
@@ -912,6 +912,58 @@ async fn test_lost_channel_creation(
],
);
}
+#[gpui::test]
+async fn test_guest_access(
+ deterministic: Arc<Deterministic>,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ deterministic.forbid_parking();
+
+ let mut server = TestServer::start(&deterministic).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+
+ let channels = server
+ .make_channel_tree(&[("channel-a", None)], (&client_a, cx_a))
+ .await;
+ let channel_a_id = channels[0];
+
+ let active_call_b = cx_b.read(ActiveCall::global);
+
+ // should not be allowed to join
+ assert!(active_call_b
+ .update(cx_b, |call, cx| call.join_channel(channel_a_id, cx))
+ .await
+ .is_err());
+
+ client_a
+ .channel_store()
+ .update(cx_a, |channel_store, cx| {
+ channel_store.set_channel_visibility(channel_a_id, proto::ChannelVisibility::Public, cx)
+ })
+ .await
+ .unwrap();
+
+ active_call_b
+ .update(cx_b, |call, cx| call.join_channel(channel_a_id, cx))
+ .await
+ .unwrap();
+
+ deterministic.run_until_parked();
+
+ assert!(client_b
+ .channel_store()
+ .update(cx_b, |channel_store, _| channel_store
+ .channel_for_id(channel_a_id)
+ .is_some()));
+
+ client_a.channel_store().update(cx_a, |channel_store, _| {
+ let participants = channel_store.channel_participants(channel_a_id);
+ assert_eq!(participants.len(), 1);
+ assert_eq!(participants[0].id, client_b.user_id().unwrap());
+ })
+}
#[gpui::test]
async fn test_channel_moving(
@@ -1,4 +1,4 @@
-use channel::{Channel, ChannelId, ChannelMembership, ChannelStore};
+use channel::{ChannelId, ChannelMembership, ChannelStore};
use client::{
proto::{self, ChannelRole, ChannelVisibility},
User, UserId, UserStore,