crates/channel/src/channel.rs 🔗
@@ -14,4 +14,5 @@ mod channel_store_tests;
pub fn init(client: &Arc<Client>) {
channel_buffer::init(client);
+ channel_chat::init(client);
}
Max Brunsfeld created
crates/channel/src/channel.rs | 1
crates/channel/src/channel_chat.rs | 8
crates/collab/migrations.sqlite/20221109000000_test_schema.sql | 20
crates/collab/migrations/20230907114200_add_channel_messages.sql | 19
crates/collab/src/db/ids.rs | 2
crates/collab/src/db/queries.rs | 1
crates/collab/src/db/queries/messages.rs | 152 ++
crates/collab/src/db/tables.rs | 2
crates/collab/src/db/tables/channel.rs | 8
crates/collab/src/db/tables/channel_chat_participant.rs | 41
crates/collab/src/db/tables/channel_message.rs | 45
crates/collab/src/db/tests.rs | 1
crates/collab/src/db/tests/message_tests.rs | 53
crates/collab/src/rpc.rs | 119 +
crates/collab/src/tests.rs | 1
crates/collab/src/tests/channel_message_tests.rs | 56
16 files changed, 524 insertions(+), 5 deletions(-)
@@ -14,4 +14,5 @@ mod channel_store_tests;
pub fn init(client: &Arc<Client>) {
channel_buffer::init(client);
+ channel_chat::init(client);
}
@@ -57,6 +57,10 @@ pub enum ChannelChatEvent {
},
}
+pub fn init(client: &Arc<Client>) {
+ client.add_model_message_handler(ChannelChat::handle_message_sent);
+}
+
impl Entity for ChannelChat {
type Event = ChannelChatEvent;
@@ -70,10 +74,6 @@ impl Entity for ChannelChat {
}
impl ChannelChat {
- pub fn init(rpc: &Arc<Client>) {
- rpc.add_model_message_handler(Self::handle_message_sent);
- }
-
pub async fn new(
channel: Arc<Channel>,
user_store: ModelHandle<UserStore>,
@@ -192,6 +192,26 @@ CREATE TABLE "channels" (
"created_at" TIMESTAMP NOT NULL DEFAULT now
);
+CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
+ "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+ "user_id" INTEGER NOT NULL REFERENCES users (id),
+ "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+ "connection_id" INTEGER NOT NULL,
+ "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE
+);
+CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id");
+
+CREATE TABLE IF NOT EXISTS "channel_messages" (
+ "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+ "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+ "sender_id" INTEGER NOT NULL REFERENCES users (id),
+ "body" TEXT NOT NULL,
+ "sent_at" TIMESTAMP,
+ "nonce" BLOB NOT NULL
+);
+CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
+CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce");
+
CREATE TABLE "channel_paths" (
"id_path" TEXT NOT NULL PRIMARY KEY,
"channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE
@@ -0,0 +1,19 @@
+CREATE TABLE IF NOT EXISTS "channel_messages" (
+ "id" SERIAL PRIMARY KEY,
+ "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+ "sender_id" INTEGER NOT NULL REFERENCES users (id),
+ "body" TEXT NOT NULL,
+ "sent_at" TIMESTAMP,
+ "nonce" UUID NOT NULL
+);
+CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
+CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce");
+
+CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
+ "id" SERIAL PRIMARY KEY,
+ "user_id" INTEGER NOT NULL REFERENCES users (id),
+ "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+ "connection_id" INTEGER NOT NULL,
+ "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE
+);
+CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id");
@@ -112,8 +112,10 @@ fn value_to_integer(v: Value) -> Result<i32, ValueTypeErr> {
id_type!(BufferId);
id_type!(AccessTokenId);
+id_type!(ChannelChatParticipantId);
id_type!(ChannelId);
id_type!(ChannelMemberId);
+id_type!(MessageId);
id_type!(ContactId);
id_type!(FollowerId);
id_type!(RoomId);
@@ -4,6 +4,7 @@ pub mod access_tokens;
pub mod buffers;
pub mod channels;
pub mod contacts;
+pub mod messages;
pub mod projects;
pub mod rooms;
pub mod servers;
@@ -0,0 +1,152 @@
+use super::*;
+use time::OffsetDateTime;
+
+impl Database {
+ pub async fn join_channel_chat(
+ &self,
+ channel_id: ChannelId,
+ connection_id: ConnectionId,
+ user_id: UserId,
+ ) -> Result<()> {
+ self.transaction(|tx| async move {
+ self.check_user_is_channel_member(channel_id, user_id, &*tx)
+ .await?;
+ channel_chat_participant::ActiveModel {
+ id: ActiveValue::NotSet,
+ channel_id: ActiveValue::Set(channel_id),
+ user_id: ActiveValue::Set(user_id),
+ connection_id: ActiveValue::Set(connection_id.id as i32),
+ connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
+ }
+ .insert(&*tx)
+ .await?;
+ Ok(())
+ })
+ .await
+ }
+
+ pub async fn leave_channel_chat(
+ &self,
+ channel_id: ChannelId,
+ connection_id: ConnectionId,
+ _user_id: UserId,
+ ) -> Result<()> {
+ self.transaction(|tx| async move {
+ channel_chat_participant::Entity::delete_many()
+ .filter(
+ Condition::all()
+ .add(
+ channel_chat_participant::Column::ConnectionServerId
+ .eq(connection_id.owner_id),
+ )
+ .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
+ .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
+ )
+ .exec(&*tx)
+ .await?;
+
+ Ok(())
+ })
+ .await
+ }
+
+ pub async fn get_channel_messages(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ count: usize,
+ before_message_id: Option<MessageId>,
+ ) -> Result<Vec<proto::ChannelMessage>> {
+ self.transaction(|tx| async move {
+ self.check_user_is_channel_member(channel_id, user_id, &*tx)
+ .await?;
+
+ let mut condition =
+ Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
+
+ if let Some(before_message_id) = before_message_id {
+ condition = condition.add(channel_message::Column::Id.lt(before_message_id));
+ }
+
+ let mut rows = channel_message::Entity::find()
+ .filter(condition)
+ .limit(count as u64)
+ .stream(&*tx)
+ .await?;
+
+ let mut messages = Vec::new();
+ while let Some(row) = rows.next().await {
+ let row = row?;
+ let nonce = row.nonce.as_u64_pair();
+ messages.push(proto::ChannelMessage {
+ id: row.id.to_proto(),
+ sender_id: row.sender_id.to_proto(),
+ body: row.body,
+ timestamp: row.sent_at.unix_timestamp() as u64,
+ nonce: Some(proto::Nonce {
+ upper_half: nonce.0,
+ lower_half: nonce.1,
+ }),
+ });
+ }
+
+ Ok(messages)
+ })
+ .await
+ }
+
+ pub async fn create_channel_message(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ body: &str,
+ timestamp: OffsetDateTime,
+ nonce: u128,
+ ) -> Result<(MessageId, Vec<ConnectionId>)> {
+ self.transaction(|tx| async move {
+ let mut rows = channel_chat_participant::Entity::find()
+ .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
+ .stream(&*tx)
+ .await?;
+
+ let mut is_participant = false;
+ let mut participant_connection_ids = Vec::new();
+ while let Some(row) = rows.next().await {
+ let row = row?;
+ if row.user_id == user_id {
+ is_participant = true;
+ }
+ participant_connection_ids.push(row.connection());
+ }
+ drop(rows);
+
+ if !is_participant {
+ Err(anyhow!("not a chat participant"))?;
+ }
+
+ let message = channel_message::Entity::insert(channel_message::ActiveModel {
+ channel_id: ActiveValue::Set(channel_id),
+ sender_id: ActiveValue::Set(user_id),
+ body: ActiveValue::Set(body.to_string()),
+ sent_at: ActiveValue::Set(timestamp),
+ nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
+ id: ActiveValue::NotSet,
+ })
+ .on_conflict(
+ OnConflict::column(channel_message::Column::Nonce)
+ .update_column(channel_message::Column::Nonce)
+ .to_owned(),
+ )
+ .exec(&*tx)
+ .await?;
+
+ #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+ enum QueryConnectionId {
+ ConnectionId,
+ }
+
+ Ok((message.last_insert_id, participant_connection_ids))
+ })
+ .await
+ }
+}
@@ -4,7 +4,9 @@ pub mod buffer_operation;
pub mod buffer_snapshot;
pub mod channel;
pub mod channel_buffer_collaborator;
+pub mod channel_chat_participant;
pub mod channel_member;
+pub mod channel_message;
pub mod channel_path;
pub mod contact;
pub mod feature_flag;
@@ -21,6 +21,8 @@ pub enum Relation {
Member,
#[sea_orm(has_many = "super::channel_buffer_collaborator::Entity")]
BufferCollaborators,
+ #[sea_orm(has_many = "super::channel_chat_participant::Entity")]
+ ChatParticipants,
}
impl Related<super::channel_member::Entity> for Entity {
@@ -46,3 +48,9 @@ impl Related<super::channel_buffer_collaborator::Entity> for Entity {
Relation::BufferCollaborators.def()
}
}
+
+impl Related<super::channel_chat_participant::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::ChatParticipants.def()
+ }
+}
@@ -0,0 +1,41 @@
+use crate::db::{ChannelChatParticipantId, ChannelId, ServerId, UserId};
+use rpc::ConnectionId;
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "channel_chat_participants")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: ChannelChatParticipantId,
+ pub channel_id: ChannelId,
+ pub user_id: UserId,
+ pub connection_id: i32,
+ pub connection_server_id: ServerId,
+}
+
+impl Model {
+ pub fn connection(&self) -> ConnectionId {
+ ConnectionId {
+ owner_id: self.connection_server_id.0 as u32,
+ id: self.connection_id as u32,
+ }
+ }
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::channel::Entity",
+ from = "Column::ChannelId",
+ to = "super::channel::Column::Id"
+ )]
+ Channel,
+}
+
+impl Related<super::channel::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Channel.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -0,0 +1,45 @@
+use crate::db::{ChannelId, MessageId, UserId};
+use sea_orm::entity::prelude::*;
+use time::OffsetDateTime;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "channel_messages")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: MessageId,
+ pub channel_id: ChannelId,
+ pub sender_id: UserId,
+ pub body: String,
+ pub sent_at: OffsetDateTime,
+ pub nonce: Uuid,
+}
+
+impl ActiveModelBehavior for ActiveModel {}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::channel::Entity",
+ from = "Column::ChannelId",
+ to = "super::channel::Column::Id"
+ )]
+ Channel,
+ #[sea_orm(
+ belongs_to = "super::user::Entity",
+ from = "Column::SenderId",
+ to = "super::user::Column::Id"
+ )]
+ Sender,
+}
+
+impl Related<super::channel::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Channel.def()
+ }
+}
+
+impl Related<super::user::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Sender.def()
+ }
+}
@@ -1,6 +1,7 @@
mod buffer_tests;
mod db_tests;
mod feature_flag_tests;
+mod message_tests;
use super::*;
use gpui::executor::Background;
@@ -0,0 +1,53 @@
+use crate::{
+ db::{Database, NewUserParams},
+ test_both_dbs,
+};
+use std::sync::Arc;
+use time::OffsetDateTime;
+
+test_both_dbs!(
+ test_channel_message_nonces,
+ test_channel_message_nonces_postgres,
+ test_channel_message_nonces_sqlite
+);
+
+async fn test_channel_message_nonces(db: &Arc<Database>) {
+ let user = db
+ .create_user(
+ "user@example.com",
+ false,
+ NewUserParams {
+ github_login: "user".into(),
+ github_user_id: 1,
+ invite_count: 0,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+ let channel = db
+ .create_channel("channel", None, "room", user)
+ .await
+ .unwrap();
+
+ let msg1_id = db
+ .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
+ .await
+ .unwrap();
+ let msg2_id = db
+ .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
+ .await
+ .unwrap();
+ let msg3_id = db
+ .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
+ .await
+ .unwrap();
+ let msg4_id = db
+ .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
+ .await
+ .unwrap();
+
+ assert_ne!(msg1_id, msg2_id);
+ assert_eq!(msg1_id, msg3_id);
+ assert_eq!(msg2_id, msg4_id);
+}
@@ -2,7 +2,10 @@ mod connection_pool;
use crate::{
auth,
- db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
+ db::{
+ self, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, ServerId, User,
+ UserId,
+ },
executor::Executor,
AppState, Result,
};
@@ -56,6 +59,7 @@ use std::{
},
time::{Duration, Instant},
};
+use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder;
use tracing::{info_span, instrument, Instrument};
@@ -63,6 +67,9 @@ use tracing::{info_span, instrument, Instrument};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
+const MESSAGE_COUNT_PER_PAGE: usize = 100;
+const MAX_MESSAGE_LEN: usize = 1024;
+
lazy_static! {
static ref METRIC_CONNECTIONS: IntGauge =
register_int_gauge!("connections", "number of connections").unwrap();
@@ -255,6 +262,10 @@ impl Server {
.add_request_handler(get_channel_members)
.add_request_handler(respond_to_channel_invite)
.add_request_handler(join_channel)
+ .add_request_handler(join_channel_chat)
+ .add_message_handler(leave_channel_chat)
+ .add_request_handler(send_channel_message)
+ .add_request_handler(get_channel_messages)
.add_request_handler(follow)
.add_message_handler(unfollow)
.add_message_handler(update_followers)
@@ -2641,6 +2652,112 @@ fn channel_buffer_updated<T: EnvelopedMessage>(
});
}
+async fn send_channel_message(
+ request: proto::SendChannelMessage,
+ response: Response<proto::SendChannelMessage>,
+ session: Session,
+) -> Result<()> {
+ // Validate the message body.
+ let body = request.body.trim().to_string();
+ if body.len() > MAX_MESSAGE_LEN {
+ return Err(anyhow!("message is too long"))?;
+ }
+ if body.is_empty() {
+ return Err(anyhow!("message can't be blank"))?;
+ }
+
+ let timestamp = OffsetDateTime::now_utc();
+ let nonce = request
+ .nonce
+ .ok_or_else(|| anyhow!("nonce can't be blank"))?;
+
+ let channel_id = ChannelId::from_proto(request.channel_id);
+ let (message_id, connection_ids) = session
+ .db()
+ .await
+ .create_channel_message(
+ channel_id,
+ session.user_id,
+ &body,
+ timestamp,
+ nonce.clone().into(),
+ )
+ .await?;
+ let message = proto::ChannelMessage {
+ sender_id: session.user_id.to_proto(),
+ id: message_id.to_proto(),
+ body,
+ timestamp: timestamp.unix_timestamp() as u64,
+ nonce: Some(nonce),
+ };
+ broadcast(Some(session.connection_id), connection_ids, |connection| {
+ session.peer.send(
+ connection,
+ proto::ChannelMessageSent {
+ channel_id: channel_id.to_proto(),
+ message: Some(message.clone()),
+ },
+ )
+ });
+ response.send(proto::SendChannelMessageResponse {
+ message: Some(message),
+ })?;
+ Ok(())
+}
+
+async fn join_channel_chat(
+ request: proto::JoinChannelChat,
+ response: Response<proto::JoinChannelChat>,
+ session: Session,
+) -> Result<()> {
+ let channel_id = ChannelId::from_proto(request.channel_id);
+
+ let db = session.db().await;
+ db.join_channel_chat(channel_id, session.connection_id, session.user_id)
+ .await?;
+ let messages = db
+ .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
+ .await?;
+ response.send(proto::JoinChannelChatResponse {
+ done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+ messages,
+ })?;
+ Ok(())
+}
+
+async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
+ let channel_id = ChannelId::from_proto(request.channel_id);
+ session
+ .db()
+ .await
+ .leave_channel_chat(channel_id, session.connection_id, session.user_id)
+ .await?;
+ Ok(())
+}
+
+async fn get_channel_messages(
+ request: proto::GetChannelMessages,
+ response: Response<proto::GetChannelMessages>,
+ session: Session,
+) -> Result<()> {
+ let channel_id = ChannelId::from_proto(request.channel_id);
+ let messages = session
+ .db()
+ .await
+ .get_channel_messages(
+ channel_id,
+ session.user_id,
+ MESSAGE_COUNT_PER_PAGE,
+ Some(MessageId::from_proto(request.before_message_id)),
+ )
+ .await?;
+ response.send(proto::GetChannelMessagesResponse {
+ done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+ messages,
+ })?;
+ Ok(())
+}
+
async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let project_connection_ids = session
@@ -2,6 +2,7 @@ use call::Room;
use gpui::{ModelHandle, TestAppContext};
mod channel_buffer_tests;
+mod channel_message_tests;
mod channel_tests;
mod integration_tests;
mod random_channel_buffer_tests;
@@ -0,0 +1,56 @@
+use crate::tests::TestServer;
+use gpui::{executor::Deterministic, TestAppContext};
+use std::sync::Arc;
+
+#[gpui::test]
+async fn test_basic_channel_messages(
+ deterministic: Arc<Deterministic>,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ deterministic.forbid_parking();
+ let mut server = TestServer::start(&deterministic).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+
+ let channel_id = server
+ .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
+ .await;
+
+ let channel_chat_a = client_a
+ .channel_store()
+ .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
+ .await
+ .unwrap();
+ let channel_chat_b = client_b
+ .channel_store()
+ .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
+ .await
+ .unwrap();
+
+ channel_chat_a
+ .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
+ .await
+ .unwrap();
+ channel_chat_a
+ .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap())
+ .await
+ .unwrap();
+
+ deterministic.run_until_parked();
+ channel_chat_b
+ .update(cx_b, |c, cx| c.send_message("three".into(), cx).unwrap())
+ .await
+ .unwrap();
+
+ deterministic.run_until_parked();
+ channel_chat_a.update(cx_a, |c, _| {
+ assert_eq!(
+ c.messages()
+ .iter()
+ .map(|m| m.body.as_str())
+ .collect::<Vec<_>>(),
+ vec!["one", "two", "three"]
+ );
+ })
+}