Fix possibility of extra mention insertion on nonce collision

Max Brunsfeld created

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql |   2 
crates/collab/migrations/20231018102700_create_mentions.sql    |   4 
crates/collab/src/db/queries/messages.rs                       |  82 
crates/collab/src/db/tests.rs                                  |  21 
crates/collab/src/db/tests/channel_tests.rs                    |  34 
crates/collab/src/db/tests/message_tests.rs                    | 220 ++-
6 files changed, 198 insertions(+), 165 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -214,7 +214,7 @@ CREATE TABLE IF NOT EXISTS "channel_messages" (
     "nonce" BLOB NOT NULL
 );
 CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
-CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce");
+CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce");
 
 CREATE TABLE "channel_message_mentions" (
     "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE,

crates/collab/migrations/20231018102700_create_mentions.sql 🔗

@@ -5,3 +5,7 @@ CREATE TABLE "channel_message_mentions" (
     "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
     PRIMARY KEY(message_id, start_offset)
 );
+
+-- We use 'on conflict update' with this index, so it should be per-user.
+CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce");
+DROP INDEX "index_channel_messages_on_nonce";

crates/collab/src/db/queries/messages.rs 🔗

@@ -1,4 +1,5 @@
 use super::*;
+use sea_orm::TryInsertResult;
 use time::OffsetDateTime;
 
 impl Database {
@@ -184,7 +185,7 @@ impl Database {
             let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
             let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
 
-            let message_id = channel_message::Entity::insert(channel_message::ActiveModel {
+            let result = channel_message::Entity::insert(channel_message::ActiveModel {
                 channel_id: ActiveValue::Set(channel_id),
                 sender_id: ActiveValue::Set(user_id),
                 body: ActiveValue::Set(body.to_string()),
@@ -193,46 +194,57 @@ impl Database {
                 id: ActiveValue::NotSet,
             })
             .on_conflict(
-                OnConflict::column(channel_message::Column::Nonce)
-                    .update_column(channel_message::Column::Nonce)
-                    .to_owned(),
+                OnConflict::columns([
+                    channel_message::Column::SenderId,
+                    channel_message::Column::Nonce,
+                ])
+                .do_nothing()
+                .to_owned(),
             )
+            .do_nothing()
             .exec(&*tx)
-            .await?
-            .last_insert_id;
-
-            let models = mentions
-                .iter()
-                .filter_map(|mention| {
-                    let range = mention.range.as_ref()?;
-                    if !body.is_char_boundary(range.start as usize)
-                        || !body.is_char_boundary(range.end as usize)
-                    {
-                        return None;
+            .await?;
+
+            let message_id;
+            match result {
+                TryInsertResult::Inserted(result) => {
+                    message_id = result.last_insert_id;
+                    let models = mentions
+                        .iter()
+                        .filter_map(|mention| {
+                            let range = mention.range.as_ref()?;
+                            if !body.is_char_boundary(range.start as usize)
+                                || !body.is_char_boundary(range.end as usize)
+                            {
+                                return None;
+                            }
+                            Some(channel_message_mention::ActiveModel {
+                                message_id: ActiveValue::Set(message_id),
+                                start_offset: ActiveValue::Set(range.start as i32),
+                                end_offset: ActiveValue::Set(range.end as i32),
+                                user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
+                            })
+                        })
+                        .collect::<Vec<_>>();
+                    if !models.is_empty() {
+                        channel_message_mention::Entity::insert_many(models)
+                            .exec(&*tx)
+                            .await?;
                     }
-                    Some(channel_message_mention::ActiveModel {
-                        message_id: ActiveValue::Set(message_id),
-                        start_offset: ActiveValue::Set(range.start as i32),
-                        end_offset: ActiveValue::Set(range.end as i32),
-                        user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
-                    })
-                })
-                .collect::<Vec<_>>();
-            if !models.is_empty() {
-                channel_message_mention::Entity::insert_many(models)
-                    .exec(&*tx)
-                    .await?;
-            }
 
-            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
-            enum QueryConnectionId {
-                ConnectionId,
+                    self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
+                        .await?;
+                }
+                _ => {
+                    message_id = channel_message::Entity::find()
+                        .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
+                        .one(&*tx)
+                        .await?
+                        .ok_or_else(|| anyhow!("failed to insert message"))?
+                        .id;
+                }
             }
 
-            // Observe this message for the sender
-            self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
-                .await?;
-
             let mut channel_members = self
                 .get_channel_participants_internal(channel_id, &*tx)
                 .await?;

crates/collab/src/db/tests.rs 🔗

@@ -10,7 +10,10 @@ use parking_lot::Mutex;
 use rpc::proto::ChannelEdge;
 use sea_orm::ConnectionTrait;
 use sqlx::migrate::MigrateDatabase;
-use std::sync::Arc;
+use std::sync::{
+    atomic::{AtomicI32, Ordering::SeqCst},
+    Arc,
+};
 
 const TEST_RELEASE_CHANNEL: &'static str = "test";
 
@@ -174,3 +177,19 @@ fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId)
 
     graph
 }
+
+static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
+
+async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
+    db.create_user(
+        email,
+        false,
+        NewUserParams {
+            github_login: email[0..email.find("@").unwrap()].to_string(),
+            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
+        },
+    )
+    .await
+    .unwrap()
+    .user_id
+}

crates/collab/src/db/tests/channel_tests.rs 🔗

@@ -1,21 +1,17 @@
-use collections::{HashMap, HashSet};
-use rpc::{
-    proto::{self},
-    ConnectionId,
-};
-
 use crate::{
     db::{
         queries::channels::ChannelGraph,
-        tests::{graph, TEST_RELEASE_CHANNEL},
-        ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId,
+        tests::{graph, new_test_user, TEST_RELEASE_CHANNEL},
+        ChannelId, ChannelRole, Database, NewUserParams, RoomId,
     },
     test_both_dbs,
 };
-use std::sync::{
-    atomic::{AtomicI32, Ordering},
-    Arc,
+use collections::{HashMap, HashSet};
+use rpc::{
+    proto::{self},
+    ConnectionId,
 };
+use std::sync::Arc;
 
 test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
 
@@ -1105,19 +1101,3 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option<ChannelId>)])
 
     pretty_assertions::assert_eq!(actual_map, expected_map)
 }
-
-static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
-
-async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
-    db.create_user(
-        email,
-        false,
-        NewUserParams {
-            github_login: email[0..email.find("@").unwrap()].to_string(),
-            github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst),
-        },
-    )
-    .await
-    .unwrap()
-    .user_id
-}

crates/collab/src/db/tests/message_tests.rs 🔗

@@ -1,5 +1,6 @@
+use super::new_test_user;
 use crate::{
-    db::{ChannelRole, Database, MessageId, NewUserParams},
+    db::{ChannelRole, Database, MessageId},
     test_both_dbs,
 };
 use channel::mentions_to_proto;
@@ -13,18 +14,7 @@ test_both_dbs!(
 );
 
 async fn test_channel_message_retrieval(db: &Arc<Database>) {
-    let user = db
-        .create_user(
-            "user@example.com",
-            false,
-            NewUserParams {
-                github_login: "user".into(),
-                github_user_id: 1,
-            },
-        )
-        .await
-        .unwrap()
-        .user_id;
+    let user = new_test_user(db, "user@example.com").await;
     let channel = db.create_channel("channel", None, user).await.unwrap();
 
     let owner_id = db.create_server("test").await.unwrap().0 as u32;
@@ -81,46 +71,129 @@ test_both_dbs!(
 );
 
 async fn test_channel_message_nonces(db: &Arc<Database>) {
-    let user = db
-        .create_user(
-            "user@example.com",
-            false,
-            NewUserParams {
-                github_login: "user".into(),
-                github_user_id: 1,
-            },
-        )
+    let user_a = new_test_user(db, "user_a@example.com").await;
+    let user_b = new_test_user(db, "user_b@example.com").await;
+    let user_c = new_test_user(db, "user_c@example.com").await;
+    let channel = db.create_channel("channel", None, user_a).await.unwrap();
+    db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
         .await
-        .unwrap()
-        .user_id;
-    let channel = db.create_channel("channel", None, user).await.unwrap();
-
-    let owner_id = db.create_server("test").await.unwrap().0 as u32;
-
-    db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user)
+        .unwrap();
+    db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member)
         .await
         .unwrap();
-
-    let msg1_id = db
-        .create_channel_message(channel, user, "1", &[], OffsetDateTime::now_utc(), 1)
+    db.respond_to_channel_invite(channel, user_b, true)
         .await
         .unwrap();
-    let msg2_id = db
-        .create_channel_message(channel, user, "2", &[], OffsetDateTime::now_utc(), 2)
+    db.respond_to_channel_invite(channel, user_c, true)
         .await
         .unwrap();
-    let msg3_id = db
-        .create_channel_message(channel, user, "3", &[], OffsetDateTime::now_utc(), 1)
+
+    let owner_id = db.create_server("test").await.unwrap().0 as u32;
+    db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a)
         .await
         .unwrap();
-    let msg4_id = db
-        .create_channel_message(channel, user, "4", &[], OffsetDateTime::now_utc(), 2)
+    db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b)
         .await
         .unwrap();
 
-    assert_ne!(msg1_id, msg2_id);
-    assert_eq!(msg1_id, msg3_id);
-    assert_eq!(msg2_id, msg4_id);
+    // As user A, create messages that re-use the same nonces. The requests
+    // succeed, but return the same ids.
+    let id1 = db
+        .create_channel_message(
+            channel,
+            user_a,
+            "hi @user_b",
+            &mentions_to_proto(&[(3..10, user_b.to_proto())]),
+            OffsetDateTime::now_utc(),
+            100,
+        )
+        .await
+        .unwrap()
+        .0;
+    let id2 = db
+        .create_channel_message(
+            channel,
+            user_a,
+            "hello, fellow users",
+            &mentions_to_proto(&[]),
+            OffsetDateTime::now_utc(),
+            200,
+        )
+        .await
+        .unwrap()
+        .0;
+    let id3 = db
+        .create_channel_message(
+            channel,
+            user_a,
+            "bye @user_c (same nonce as first message)",
+            &mentions_to_proto(&[(4..11, user_c.to_proto())]),
+            OffsetDateTime::now_utc(),
+            100,
+        )
+        .await
+        .unwrap()
+        .0;
+    let id4 = db
+        .create_channel_message(
+            channel,
+            user_a,
+            "omg (same nonce as second message)",
+            &mentions_to_proto(&[]),
+            OffsetDateTime::now_utc(),
+            200,
+        )
+        .await
+        .unwrap()
+        .0;
+
+    // As a different user, reuse one of the same nonces. This request succeeds
+    // and returns a different id.
+    let id5 = db
+        .create_channel_message(
+            channel,
+            user_b,
+            "omg @user_a (same nonce as user_a's first message)",
+            &mentions_to_proto(&[(4..11, user_a.to_proto())]),
+            OffsetDateTime::now_utc(),
+            100,
+        )
+        .await
+        .unwrap()
+        .0;
+
+    assert_ne!(id1, id2);
+    assert_eq!(id1, id3);
+    assert_eq!(id2, id4);
+    assert_ne!(id5, id1);
+
+    let messages = db
+        .get_channel_messages(channel, user_a, 5, None)
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|m| (m.id, m.body, m.mentions))
+        .collect::<Vec<_>>();
+    assert_eq!(
+        messages,
+        &[
+            (
+                id1.to_proto(),
+                "hi @user_b".into(),
+                mentions_to_proto(&[(3..10, user_b.to_proto())]),
+            ),
+            (
+                id2.to_proto(),
+                "hello, fellow users".into(),
+                mentions_to_proto(&[])
+            ),
+            (
+                id5.to_proto(),
+                "omg @user_a (same nonce as user_a's first message)".into(),
+                mentions_to_proto(&[(4..11, user_a.to_proto())]),
+            ),
+        ]
+    );
 }
 
 test_both_dbs!(
@@ -130,30 +203,8 @@ test_both_dbs!(
 );
 
 async fn test_unseen_channel_messages(db: &Arc<Database>) {
-    let user = db
-        .create_user(
-            "user_a@example.com",
-            false,
-            NewUserParams {
-                github_login: "user_a".into(),
-                github_user_id: 1,
-            },
-        )
-        .await
-        .unwrap()
-        .user_id;
-    let observer = db
-        .create_user(
-            "user_b@example.com",
-            false,
-            NewUserParams {
-                github_login: "user_b".into(),
-                github_user_id: 2,
-            },
-        )
-        .await
-        .unwrap()
-        .user_id;
+    let user = new_test_user(db, "user_a@example.com").await;
+    let observer = new_test_user(db, "user_b@example.com").await;
 
     let channel_1 = db.create_channel("channel", None, user).await.unwrap();
     let channel_2 = db.create_channel("channel-2", None, user).await.unwrap();
@@ -304,42 +355,9 @@ test_both_dbs!(
 );
 
 async fn test_channel_message_mentions(db: &Arc<Database>) {
-    let user_a = db
-        .create_user(
-            "user_a@example.com",
-            false,
-            NewUserParams {
-                github_login: "user_a".into(),
-                github_user_id: 1,
-            },
-        )
-        .await
-        .unwrap()
-        .user_id;
-    let user_b = db
-        .create_user(
-            "user_b@example.com",
-            false,
-            NewUserParams {
-                github_login: "user_b".into(),
-                github_user_id: 2,
-            },
-        )
-        .await
-        .unwrap()
-        .user_id;
-    let user_c = db
-        .create_user(
-            "user_b@example.com",
-            false,
-            NewUserParams {
-                github_login: "user_c".into(),
-                github_user_id: 3,
-            },
-        )
-        .await
-        .unwrap()
-        .user_id;
+    let user_a = new_test_user(db, "user_a@example.com").await;
+    let user_b = new_test_user(db, "user_b@example.com").await;
+    let user_c = new_test_user(db, "user_c@example.com").await;
 
     let channel = db.create_channel("channel", None, user_a).await.unwrap();
     db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)