Start work on restoring server-side code for chat messages

Max Brunsfeld created

Change summary

crates/channel/src/channel.rs                                    |   1 
crates/channel/src/channel_chat.rs                               |   8 
crates/collab/migrations.sqlite/20221109000000_test_schema.sql   |  20 
crates/collab/migrations/20230907114200_add_channel_messages.sql |  19 
crates/collab/src/db/ids.rs                                      |   2 
crates/collab/src/db/queries.rs                                  |   1 
crates/collab/src/db/queries/messages.rs                         | 152 ++
crates/collab/src/db/tables.rs                                   |   2 
crates/collab/src/db/tables/channel.rs                           |   8 
crates/collab/src/db/tables/channel_chat_participant.rs          |  41 
crates/collab/src/db/tables/channel_message.rs                   |  45 
crates/collab/src/db/tests.rs                                    |   1 
crates/collab/src/db/tests/message_tests.rs                      |  53 
crates/collab/src/rpc.rs                                         | 119 +
crates/collab/src/tests.rs                                       |   1 
crates/collab/src/tests/channel_message_tests.rs                 |  56 
16 files changed, 524 insertions(+), 5 deletions(-)

Detailed changes

crates/channel/src/channel.rs 🔗

@@ -14,4 +14,5 @@ mod channel_store_tests;
 
 pub fn init(client: &Arc<Client>) {
     channel_buffer::init(client);
+    channel_chat::init(client);
 }

crates/channel/src/channel_chat.rs 🔗

@@ -57,6 +57,10 @@ pub enum ChannelChatEvent {
     },
 }
 
+pub fn init(client: &Arc<Client>) {
+    client.add_model_message_handler(ChannelChat::handle_message_sent);
+}
+
 impl Entity for ChannelChat {
     type Event = ChannelChatEvent;
 
@@ -70,10 +74,6 @@ impl Entity for ChannelChat {
 }
 
 impl ChannelChat {
-    pub fn init(rpc: &Arc<Client>) {
-        rpc.add_model_message_handler(Self::handle_message_sent);
-    }
-
     pub async fn new(
         channel: Arc<Channel>,
         user_store: ModelHandle<UserStore>,

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -192,6 +192,26 @@ CREATE TABLE "channels" (
     "created_at" TIMESTAMP NOT NULL DEFAULT now
 );
 
+CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+    "user_id" INTEGER NOT NULL REFERENCES users (id),
+    "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+    "connection_id" INTEGER NOT NULL,
+    "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE
+);
+CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id");
+
+CREATE TABLE IF NOT EXISTS "channel_messages" (
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+    "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+    "sender_id" INTEGER NOT NULL REFERENCES users (id),
+    "body" TEXT NOT NULL,
+    "sent_at" TIMESTAMP,
+    "nonce" BLOB NOT NULL
+);
+CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
+CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce");
+
 CREATE TABLE "channel_paths" (
     "id_path" TEXT NOT NULL PRIMARY KEY,
     "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE

crates/collab/migrations/20230907114200_add_channel_messages.sql 🔗

@@ -0,0 +1,19 @@
+CREATE TABLE IF NOT EXISTS "channel_messages" (
+    "id" SERIAL PRIMARY KEY,
+    "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+    "sender_id" INTEGER NOT NULL REFERENCES users (id),
+    "body" TEXT NOT NULL,
+    "sent_at" TIMESTAMP,
+    "nonce" UUID NOT NULL
+);
+CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
+CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce");
+
+CREATE TABLE IF NOT EXISTS "channel_chat_participants" (
+    "id" SERIAL PRIMARY KEY,
+    "user_id" INTEGER NOT NULL REFERENCES users (id),
+    "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+    "connection_id" INTEGER NOT NULL,
+    "connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE
+);
+CREATE INDEX "index_channel_chat_participants_on_channel_id" ON "channel_chat_participants" ("channel_id");

crates/collab/src/db/ids.rs 🔗

@@ -112,8 +112,10 @@ fn value_to_integer(v: Value) -> Result<i32, ValueTypeErr> {
 
 id_type!(BufferId);
 id_type!(AccessTokenId);
+id_type!(ChannelChatParticipantId);
 id_type!(ChannelId);
 id_type!(ChannelMemberId);
+id_type!(MessageId);
 id_type!(ContactId);
 id_type!(FollowerId);
 id_type!(RoomId);

crates/collab/src/db/queries.rs 🔗

@@ -4,6 +4,7 @@ pub mod access_tokens;
 pub mod buffers;
 pub mod channels;
 pub mod contacts;
+pub mod messages;
 pub mod projects;
 pub mod rooms;
 pub mod servers;

crates/collab/src/db/queries/messages.rs 🔗

@@ -0,0 +1,152 @@
+use super::*;
+use time::OffsetDateTime;
+
+impl Database {
+    pub async fn join_channel_chat(
+        &self,
+        channel_id: ChannelId,
+        connection_id: ConnectionId,
+        user_id: UserId,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            self.check_user_is_channel_member(channel_id, user_id, &*tx)
+                .await?;
+            channel_chat_participant::ActiveModel {
+                id: ActiveValue::NotSet,
+                channel_id: ActiveValue::Set(channel_id),
+                user_id: ActiveValue::Set(user_id),
+                connection_id: ActiveValue::Set(connection_id.id as i32),
+                connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
+            }
+            .insert(&*tx)
+            .await?;
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn leave_channel_chat(
+        &self,
+        channel_id: ChannelId,
+        connection_id: ConnectionId,
+        _user_id: UserId,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            channel_chat_participant::Entity::delete_many()
+                .filter(
+                    Condition::all()
+                        .add(
+                            channel_chat_participant::Column::ConnectionServerId
+                                .eq(connection_id.owner_id),
+                        )
+                        .add(channel_chat_participant::Column::ConnectionId.eq(connection_id.id))
+                        .add(channel_chat_participant::Column::ChannelId.eq(channel_id)),
+                )
+                .exec(&*tx)
+                .await?;
+
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn get_channel_messages(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        count: usize,
+        before_message_id: Option<MessageId>,
+    ) -> Result<Vec<proto::ChannelMessage>> {
+        self.transaction(|tx| async move {
+            self.check_user_is_channel_member(channel_id, user_id, &*tx)
+                .await?;
+
+            let mut condition =
+                Condition::all().add(channel_message::Column::ChannelId.eq(channel_id));
+
+            if let Some(before_message_id) = before_message_id {
+                condition = condition.add(channel_message::Column::Id.lt(before_message_id));
+            }
+
+            let mut rows = channel_message::Entity::find()
+                .filter(condition)
+                .limit(count as u64)
+                .stream(&*tx)
+                .await?;
+
+            let mut messages = Vec::new();
+            while let Some(row) = rows.next().await {
+                let row = row?;
+                let nonce = row.nonce.as_u64_pair();
+                messages.push(proto::ChannelMessage {
+                    id: row.id.to_proto(),
+                    sender_id: row.sender_id.to_proto(),
+                    body: row.body,
+                    timestamp: row.sent_at.unix_timestamp() as u64,
+                    nonce: Some(proto::Nonce {
+                        upper_half: nonce.0,
+                        lower_half: nonce.1,
+                    }),
+                });
+            }
+
+            Ok(messages)
+        })
+        .await
+    }
+
+    pub async fn create_channel_message(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        body: &str,
+        timestamp: OffsetDateTime,
+        nonce: u128,
+    ) -> Result<(MessageId, Vec<ConnectionId>)> {
+        self.transaction(|tx| async move {
+            let mut rows = channel_chat_participant::Entity::find()
+                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
+                .stream(&*tx)
+                .await?;
+
+            let mut is_participant = false;
+            let mut participant_connection_ids = Vec::new();
+            while let Some(row) = rows.next().await {
+                let row = row?;
+                if row.user_id == user_id {
+                    is_participant = true;
+                }
+                participant_connection_ids.push(row.connection());
+            }
+            drop(rows);
+
+            if !is_participant {
+                Err(anyhow!("not a chat participant"))?;
+            }
+
+            let message = channel_message::Entity::insert(channel_message::ActiveModel {
+                channel_id: ActiveValue::Set(channel_id),
+                sender_id: ActiveValue::Set(user_id),
+                body: ActiveValue::Set(body.to_string()),
+                sent_at: ActiveValue::Set(timestamp),
+                nonce: ActiveValue::Set(Uuid::from_u128(nonce)),
+                id: ActiveValue::NotSet,
+            })
+            .on_conflict(
+                OnConflict::column(channel_message::Column::Nonce)
+                    .update_column(channel_message::Column::Nonce)
+                    .to_owned(),
+            )
+            .exec(&*tx)
+            .await?;
+
+            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+            enum QueryConnectionId {
+                ConnectionId,
+            }
+
+            Ok((message.last_insert_id, participant_connection_ids))
+        })
+        .await
+    }
+}

crates/collab/src/db/tables.rs 🔗

@@ -4,7 +4,9 @@ pub mod buffer_operation;
 pub mod buffer_snapshot;
 pub mod channel;
 pub mod channel_buffer_collaborator;
+pub mod channel_chat_participant;
 pub mod channel_member;
+pub mod channel_message;
 pub mod channel_path;
 pub mod contact;
 pub mod feature_flag;

crates/collab/src/db/tables/channel.rs 🔗

@@ -21,6 +21,8 @@ pub enum Relation {
     Member,
     #[sea_orm(has_many = "super::channel_buffer_collaborator::Entity")]
     BufferCollaborators,
+    #[sea_orm(has_many = "super::channel_chat_participant::Entity")]
+    ChatParticipants,
 }
 
 impl Related<super::channel_member::Entity> for Entity {
@@ -46,3 +48,9 @@ impl Related<super::channel_buffer_collaborator::Entity> for Entity {
         Relation::BufferCollaborators.def()
     }
 }
+
+impl Related<super::channel_chat_participant::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::ChatParticipants.def()
+    }
+}

crates/collab/src/db/tables/channel_chat_participant.rs 🔗

@@ -0,0 +1,41 @@
+use crate::db::{ChannelChatParticipantId, ChannelId, ServerId, UserId};
+use rpc::ConnectionId;
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "channel_chat_participants")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: ChannelChatParticipantId,
+    pub channel_id: ChannelId,
+    pub user_id: UserId,
+    pub connection_id: i32,
+    pub connection_server_id: ServerId,
+}
+
+impl Model {
+    pub fn connection(&self) -> ConnectionId {
+        ConnectionId {
+            owner_id: self.connection_server_id.0 as u32,
+            id: self.connection_id as u32,
+        }
+    }
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::channel::Entity",
+        from = "Column::ChannelId",
+        to = "super::channel::Column::Id"
+    )]
+    Channel,
+}
+
+impl Related<super::channel::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Channel.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/db/tables/channel_message.rs 🔗

@@ -0,0 +1,45 @@
+use crate::db::{ChannelId, MessageId, UserId};
+use sea_orm::entity::prelude::*;
+use time::OffsetDateTime;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "channel_messages")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: MessageId,
+    pub channel_id: ChannelId,
+    pub sender_id: UserId,
+    pub body: String,
+    pub sent_at: OffsetDateTime,
+    pub nonce: Uuid,
+}
+
+impl ActiveModelBehavior for ActiveModel {}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::channel::Entity",
+        from = "Column::ChannelId",
+        to = "super::channel::Column::Id"
+    )]
+    Channel,
+    #[sea_orm(
+        belongs_to = "super::user::Entity",
+        from = "Column::SenderId",
+        to = "super::user::Column::Id"
+    )]
+    Sender,
+}
+
+impl Related<super::channel::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Channel.def()
+    }
+}
+
+impl Related<super::user::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Sender.def()
+    }
+}

crates/collab/src/db/tests.rs 🔗

@@ -1,6 +1,7 @@
 mod buffer_tests;
 mod db_tests;
 mod feature_flag_tests;
+mod message_tests;
 
 use super::*;
 use gpui::executor::Background;

crates/collab/src/db/tests/message_tests.rs 🔗

@@ -0,0 +1,53 @@
+use crate::{
+    db::{Database, NewUserParams},
+    test_both_dbs,
+};
+use std::sync::Arc;
+use time::OffsetDateTime;
+
+test_both_dbs!(
+    test_channel_message_nonces,
+    test_channel_message_nonces_postgres,
+    test_channel_message_nonces_sqlite
+);
+
+async fn test_channel_message_nonces(db: &Arc<Database>) {
+    let user = db
+        .create_user(
+            "user@example.com",
+            false,
+            NewUserParams {
+                github_login: "user".into(),
+                github_user_id: 1,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+    let channel = db
+        .create_channel("channel", None, "room", user)
+        .await
+        .unwrap();
+
+    let msg1_id = db
+        .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
+        .await
+        .unwrap();
+    let msg2_id = db
+        .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
+        .await
+        .unwrap();
+    let msg3_id = db
+        .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
+        .await
+        .unwrap();
+    let msg4_id = db
+        .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
+        .await
+        .unwrap();
+
+    assert_ne!(msg1_id, msg2_id);
+    assert_eq!(msg1_id, msg3_id);
+    assert_eq!(msg2_id, msg4_id);
+}

crates/collab/src/rpc.rs 🔗

@@ -2,7 +2,10 @@ mod connection_pool;
 
 use crate::{
     auth,
-    db::{self, ChannelId, ChannelsForUser, Database, ProjectId, RoomId, ServerId, User, UserId},
+    db::{
+        self, ChannelId, ChannelsForUser, Database, MessageId, ProjectId, RoomId, ServerId, User,
+        UserId,
+    },
     executor::Executor,
     AppState, Result,
 };
@@ -56,6 +59,7 @@ use std::{
     },
     time::{Duration, Instant},
 };
+use time::OffsetDateTime;
 use tokio::sync::{watch, Semaphore};
 use tower::ServiceBuilder;
 use tracing::{info_span, instrument, Instrument};
@@ -63,6 +67,9 @@ use tracing::{info_span, instrument, Instrument};
 pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
 pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
 
+const MESSAGE_COUNT_PER_PAGE: usize = 100;
+const MAX_MESSAGE_LEN: usize = 1024;
+
 lazy_static! {
     static ref METRIC_CONNECTIONS: IntGauge =
         register_int_gauge!("connections", "number of connections").unwrap();
@@ -255,6 +262,10 @@ impl Server {
             .add_request_handler(get_channel_members)
             .add_request_handler(respond_to_channel_invite)
             .add_request_handler(join_channel)
+            .add_request_handler(join_channel_chat)
+            .add_message_handler(leave_channel_chat)
+            .add_request_handler(send_channel_message)
+            .add_request_handler(get_channel_messages)
             .add_request_handler(follow)
             .add_message_handler(unfollow)
             .add_message_handler(update_followers)
@@ -2641,6 +2652,112 @@ fn channel_buffer_updated<T: EnvelopedMessage>(
     });
 }
 
+async fn send_channel_message(
+    request: proto::SendChannelMessage,
+    response: Response<proto::SendChannelMessage>,
+    session: Session,
+) -> Result<()> {
+    // Validate the message body.
+    let body = request.body.trim().to_string();
+    if body.len() > MAX_MESSAGE_LEN {
+        return Err(anyhow!("message is too long"))?;
+    }
+    if body.is_empty() {
+        return Err(anyhow!("message can't be blank"))?;
+    }
+
+    let timestamp = OffsetDateTime::now_utc();
+    let nonce = request
+        .nonce
+        .ok_or_else(|| anyhow!("nonce can't be blank"))?;
+
+    let channel_id = ChannelId::from_proto(request.channel_id);
+    let (message_id, connection_ids) = session
+        .db()
+        .await
+        .create_channel_message(
+            channel_id,
+            session.user_id,
+            &body,
+            timestamp,
+            nonce.clone().into(),
+        )
+        .await?;
+    let message = proto::ChannelMessage {
+        sender_id: session.user_id.to_proto(),
+        id: message_id.to_proto(),
+        body,
+        timestamp: timestamp.unix_timestamp() as u64,
+        nonce: Some(nonce),
+    };
+    broadcast(Some(session.connection_id), connection_ids, |connection| {
+        session.peer.send(
+            connection,
+            proto::ChannelMessageSent {
+                channel_id: channel_id.to_proto(),
+                message: Some(message.clone()),
+            },
+        )
+    });
+    response.send(proto::SendChannelMessageResponse {
+        message: Some(message),
+    })?;
+    Ok(())
+}
+
+async fn join_channel_chat(
+    request: proto::JoinChannelChat,
+    response: Response<proto::JoinChannelChat>,
+    session: Session,
+) -> Result<()> {
+    let channel_id = ChannelId::from_proto(request.channel_id);
+
+    let db = session.db().await;
+    db.join_channel_chat(channel_id, session.connection_id, session.user_id)
+        .await?;
+    let messages = db
+        .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
+        .await?;
+    response.send(proto::JoinChannelChatResponse {
+        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+        messages,
+    })?;
+    Ok(())
+}
+
+async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
+    let channel_id = ChannelId::from_proto(request.channel_id);
+    session
+        .db()
+        .await
+        .leave_channel_chat(channel_id, session.connection_id, session.user_id)
+        .await?;
+    Ok(())
+}
+
+async fn get_channel_messages(
+    request: proto::GetChannelMessages,
+    response: Response<proto::GetChannelMessages>,
+    session: Session,
+) -> Result<()> {
+    let channel_id = ChannelId::from_proto(request.channel_id);
+    let messages = session
+        .db()
+        .await
+        .get_channel_messages(
+            channel_id,
+            session.user_id,
+            MESSAGE_COUNT_PER_PAGE,
+            Some(MessageId::from_proto(request.before_message_id)),
+        )
+        .await?;
+    response.send(proto::GetChannelMessagesResponse {
+        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+        messages,
+    })?;
+    Ok(())
+}
+
 async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
     let project_id = ProjectId::from_proto(request.project_id);
     let project_connection_ids = session

crates/collab/src/tests.rs 🔗

@@ -2,6 +2,7 @@ use call::Room;
 use gpui::{ModelHandle, TestAppContext};
 
 mod channel_buffer_tests;
+mod channel_message_tests;
 mod channel_tests;
 mod integration_tests;
 mod random_channel_buffer_tests;

crates/collab/src/tests/channel_message_tests.rs 🔗

@@ -0,0 +1,56 @@
+use crate::tests::TestServer;
+use gpui::{executor::Deterministic, TestAppContext};
+use std::sync::Arc;
+
+#[gpui::test]
+async fn test_basic_channel_messages(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+
+    let channel_id = server
+        .make_channel("the-channel", (&client_a, cx_a), &mut [(&client_b, cx_b)])
+        .await;
+
+    let channel_chat_a = client_a
+        .channel_store()
+        .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
+        .await
+        .unwrap();
+    let channel_chat_b = client_b
+        .channel_store()
+        .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
+        .await
+        .unwrap();
+
+    channel_chat_a
+        .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
+        .await
+        .unwrap();
+    channel_chat_a
+        .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap())
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+    channel_chat_b
+        .update(cx_b, |c, cx| c.send_message("three".into(), cx).unwrap())
+        .await
+        .unwrap();
+
+    deterministic.run_until_parked();
+    channel_chat_a.update(cx_a, |c, _| {
+        assert_eq!(
+            c.messages()
+                .iter()
+                .map(|m| m.body.as_str())
+                .collect::<Vec<_>>(),
+            vec!["one", "two", "three"]
+        );
+    })
+}