Re-send pending messages after reconnecting

Antonio Scandurra created

Change summary

Cargo.lock             |  13 +++-
server/Cargo.toml      |   5 +
server/src/bin/seed.rs |   2 
server/src/db.rs       |  44 ++++++++++++++-
server/src/rpc.rs      |  45 +++++++++++++++
zed/src/channel.rs     | 119 ++++++++++++++++++++++++++++++++-----------
zrpc/proto/zed.proto   |   7 ++
zrpc/src/proto.rs      |  19 +++++++
8 files changed, 211 insertions(+), 43 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -836,7 +836,7 @@ dependencies = [
  "target_build_utils",
  "term",
  "toml 0.4.10",
- "uuid",
+ "uuid 0.5.1",
  "walkdir",
 ]
 
@@ -884,7 +884,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8e7fb075b9b54e939006aa12e1f6cd2d3194041ff4ebe7f2efcbedf18f25b667"
 dependencies = [
  "byteorder",
- "uuid",
+ "uuid 0.5.1",
 ]
 
 [[package]]
@@ -2963,7 +2963,7 @@ dependencies = [
  "byteorder",
  "cfb",
  "encoding",
- "uuid",
+ "uuid 0.5.1",
 ]
 
 [[package]]
@@ -4784,6 +4784,7 @@ dependencies = [
  "thiserror",
  "time 0.2.25",
  "url",
+ "uuid 0.8.2",
  "webpki",
  "webpki-roots",
  "whoami",
@@ -5606,6 +5607,12 @@ dependencies = [
  "sha1 0.2.0",
 ]
 
+[[package]]
+name = "uuid"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
+
 [[package]]
 name = "value-bag"
 version = "1.0.0-alpha.7"

server/Cargo.toml 🔗

@@ -5,6 +5,9 @@ edition = "2018"
 name = "zed-server"
 version = "0.1.0"
 
+[[bin]]
+name = "zed-server"
+
 [[bin]]
 name = "seed"
 required-features = ["seed-support"]
@@ -47,7 +50,7 @@ default-features = false
 
 [dependencies.sqlx]
 version = "0.5.2"
-features = ["runtime-async-std-rustls", "postgres", "time"]
+features = ["runtime-async-std-rustls", "postgres", "time", "uuid"]
 
 [dev-dependencies]
 gpui = { path = "../gpui" }

server/src/bin/seed.rs 🔗

@@ -73,7 +73,7 @@ async fn main() {
         for timestamp in timestamps {
             let sender_id = *zed_user_ids.choose(&mut rng).unwrap();
             let body = lipsum::lipsum_words(rng.gen_range(1..=50));
-            db.create_channel_message(channel_id, sender_id, &body, timestamp)
+            db.create_channel_message(channel_id, sender_id, &body, timestamp, rng.gen())
                 .await
                 .expect("failed to insert message");
         }

server/src/db.rs 🔗

@@ -1,7 +1,7 @@
 use anyhow::Context;
 use async_std::task::{block_on, yield_now};
 use serde::Serialize;
-use sqlx::{FromRow, Result};
+use sqlx::{types::Uuid, FromRow, Result};
 use time::OffsetDateTime;
 
 pub use async_sqlx_session::PostgresSessionStore as SessionStore;
@@ -402,11 +402,13 @@ impl Db {
         sender_id: UserId,
         body: &str,
         timestamp: OffsetDateTime,
+        nonce: u128,
     ) -> Result<MessageId> {
         test_support!(self, {
             let query = "
-                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
-                VALUES ($1, $2, $3, $4)
+                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
+                VALUES ($1, $2, $3, $4, $5)
+                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
                 RETURNING id
             ";
             sqlx::query_scalar(query)
@@ -414,6 +416,7 @@ impl Db {
                 .bind(sender_id.0)
                 .bind(body)
                 .bind(timestamp)
+                .bind(Uuid::from_u128(nonce))
                 .fetch_one(&self.pool)
                 .await
                 .map(MessageId)
@@ -430,7 +433,7 @@ impl Db {
             let query = r#"
                 SELECT * FROM (
                     SELECT
-                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
+                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
                     FROM
                         channel_messages
                     WHERE
@@ -514,6 +517,7 @@ pub struct ChannelMessage {
     pub sender_id: UserId,
     pub body: String,
     pub sent_at: OffsetDateTime,
+    pub nonce: Uuid,
 }
 
 #[cfg(test)]
@@ -677,7 +681,7 @@ pub mod tests {
         let org = db.create_org("org", "org").await.unwrap();
         let channel = db.create_org_channel(org, "channel").await.unwrap();
         for i in 0..10 {
-            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
+            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                 .await
                 .unwrap();
         }
@@ -697,4 +701,34 @@ pub mod tests {
             ["1", "2", "3", "4"]
         );
     }
+
+    #[gpui::test]
+    async fn test_channel_message_nonces() {
+        let test_db = TestDb::new();
+        let db = test_db.db();
+        let user = db.create_user("user", false).await.unwrap();
+        let org = db.create_org("org", "org").await.unwrap();
+        let channel = db.create_org_channel(org, "channel").await.unwrap();
+
+        let msg1_id = db
+            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
+            .await
+            .unwrap();
+        let msg2_id = db
+            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
+            .await
+            .unwrap();
+        let msg3_id = db
+            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
+            .await
+            .unwrap();
+        let msg4_id = db
+            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
+            .await
+            .unwrap();
+
+        assert_ne!(msg1_id, msg2_id);
+        assert_eq!(msg1_id, msg3_id);
+        assert_eq!(msg2_id, msg4_id);
+    }
 }

server/src/rpc.rs 🔗

@@ -602,6 +602,7 @@ impl Server {
                 body: msg.body,
                 timestamp: msg.sent_at.unix_timestamp() as u64,
                 sender_id: msg.sender_id.to_proto(),
+                nonce: Some(msg.nonce.as_u128().into()),
             })
             .collect::<Vec<_>>();
         self.peer
@@ -687,10 +688,24 @@ impl Server {
         }
 
         let timestamp = OffsetDateTime::now_utc();
+        let nonce = if let Some(nonce) = request.payload.nonce {
+            nonce
+        } else {
+            self.peer
+                .respond_with_error(
+                    receipt,
+                    proto::Error {
+                        message: "nonce can't be blank".to_string(),
+                    },
+                )
+                .await?;
+            return Ok(());
+        };
+
         let message_id = self
             .app_state
             .db
-            .create_channel_message(channel_id, user_id, &body, timestamp)
+            .create_channel_message(channel_id, user_id, &body, timestamp, nonce.clone().into())
             .await?
             .to_proto();
         let message = proto::ChannelMessage {
@@ -698,6 +713,7 @@ impl Server {
             id: message_id,
             body,
             timestamp: timestamp.unix_timestamp() as u64,
+            nonce: Some(nonce),
         };
         broadcast(request.sender_id, connection_ids, |conn_id| {
             self.peer.send(
@@ -754,6 +770,7 @@ impl Server {
                 body: msg.body,
                 timestamp: msg.sent_at.unix_timestamp() as u64,
                 sender_id: msg.sender_id.to_proto(),
+                nonce: Some(msg.nonce.as_u128().into()),
             })
             .collect::<Vec<_>>();
         self.peer
@@ -1513,6 +1530,7 @@ mod tests {
             current_user_id(&user_store_b),
             "hello A, it's B.",
             OffsetDateTime::now_utc(),
+            1,
         )
         .await
         .unwrap();
@@ -1707,6 +1725,7 @@ mod tests {
             current_user_id(&user_store_b),
             "hello A, it's B.",
             OffsetDateTime::now_utc(),
+            2,
         )
         .await
         .unwrap();
@@ -1787,6 +1806,24 @@ mod tests {
             )
         });
 
+        // Send a message from client B while it is disconnected.
+        channel_b
+            .update(&mut cx_b, |channel, cx| {
+                let task = channel
+                    .send_message("can you see this?".to_string(), cx)
+                    .unwrap();
+                assert_eq!(
+                    channel_messages(channel),
+                    &[
+                        ("user_b".to_string(), "hello A, it's B.".to_string(), false),
+                        ("user_b".to_string(), "can you see this?".to_string(), true)
+                    ]
+                );
+                task
+            })
+            .await
+            .unwrap_err();
+
         // Send a message from client A while B is disconnected.
         channel_a
             .update(&mut cx_a, |channel, cx| {
@@ -1812,7 +1849,8 @@ mod tests {
         server.allow_connections();
         cx_b.foreground().advance_clock(Duration::from_secs(10));
 
-        // Verify that B sees the new messages upon reconnection.
+        // Verify that B sees the new messages upon reconnection, as well as the message client B
+        // sent while offline.
         channel_b
             .condition(&cx_b, |channel, _| {
                 channel_messages(channel)
@@ -1820,6 +1858,7 @@ mod tests {
                         ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                         ("user_a".to_string(), "oh, hi B.".to_string(), false),
                         ("user_a".to_string(), "sup".to_string(), false),
+                        ("user_b".to_string(), "can you see this?".to_string(), false),
                     ]
             })
             .await;
@@ -1838,6 +1877,7 @@ mod tests {
                         ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                         ("user_a".to_string(), "oh, hi B.".to_string(), false),
                         ("user_a".to_string(), "sup".to_string(), false),
+                        ("user_b".to_string(), "can you see this?".to_string(), false),
                         ("user_a".to_string(), "you online?".to_string(), false),
                     ]
             })
@@ -1856,6 +1896,7 @@ mod tests {
                         ("user_b".to_string(), "hello A, it's B.".to_string(), false),
                         ("user_a".to_string(), "oh, hi B.".to_string(), false),
                         ("user_a".to_string(), "sup".to_string(), false),
+                        ("user_b".to_string(), "can you see this?".to_string(), false),
                         ("user_a".to_string(), "you online?".to_string(), false),
                         ("user_b".to_string(), "yep".to_string(), false),
                     ]

zed/src/channel.rs 🔗

@@ -9,6 +9,7 @@ use gpui::{
     Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
 };
 use postage::prelude::Stream;
+use rand::prelude::*;
 use std::{
     collections::{HashMap, HashSet},
     mem,
@@ -42,6 +43,7 @@ pub struct Channel {
     next_pending_message_id: usize,
     user_store: Arc<UserStore>,
     rpc: Arc<Client>,
+    rng: StdRng,
     _subscription: rpc::Subscription,
 }
 
@@ -51,6 +53,7 @@ pub struct ChannelMessage {
     pub body: String,
     pub timestamp: OffsetDateTime,
     pub sender: Arc<User>,
+    pub nonce: u128,
 }
 
 #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
@@ -218,6 +221,7 @@ impl Channel {
             messages: Default::default(),
             loaded_all_messages: false,
             next_pending_message_id: 0,
+            rng: StdRng::from_entropy(),
             _subscription,
         }
     }
@@ -242,6 +246,7 @@ impl Channel {
 
         let channel_id = self.details.id;
         let pending_id = ChannelMessageId::Pending(post_inc(&mut self.next_pending_message_id));
+        let nonce = self.rng.gen();
         self.insert_messages(
             SumTree::from_item(
                 ChannelMessage {
@@ -249,6 +254,7 @@ impl Channel {
                     body: body.clone(),
                     sender: current_user,
                     timestamp: OffsetDateTime::now_utc(),
+                    nonce,
                 },
                 &(),
             ),
@@ -257,7 +263,11 @@ impl Channel {
         let user_store = self.user_store.clone();
         let rpc = self.rpc.clone();
         Ok(cx.spawn(|this, mut cx| async move {
-            let request = rpc.request(proto::SendChannelMessage { channel_id, body });
+            let request = rpc.request(proto::SendChannelMessage {
+                channel_id,
+                body,
+                nonce: Some(nonce.into()),
+            });
             let response = request.await?;
             let message = ChannelMessage::from_proto(
                 response.message.ok_or_else(|| anyhow!("invalid message"))?,
@@ -265,7 +275,6 @@ impl Channel {
             )
             .await?;
             this.update(&mut cx, |this, cx| {
-                this.remove_message(pending_id, cx);
                 this.insert_messages(SumTree::from_item(message, &()), cx);
                 Ok(())
             })
@@ -312,32 +321,51 @@ impl Channel {
         let user_store = self.user_store.clone();
         let rpc = self.rpc.clone();
         let channel_id = self.details.id;
-        cx.spawn(|channel, mut cx| {
+        cx.spawn(|this, mut cx| {
             async move {
                 let response = rpc.request(proto::JoinChannel { channel_id }).await?;
                 let messages = messages_from_proto(response.messages, &user_store).await?;
                 let loaded_all_messages = response.done;
 
-                channel.update(&mut cx, |channel, cx| {
+                let pending_messages = this.update(&mut cx, |this, cx| {
                     if let Some((first_new_message, last_old_message)) =
-                        messages.first().zip(channel.messages.last())
+                        messages.first().zip(this.messages.last())
                     {
                         if first_new_message.id > last_old_message.id {
-                            let old_messages = mem::take(&mut channel.messages);
+                            let old_messages = mem::take(&mut this.messages);
                             cx.emit(ChannelEvent::MessagesUpdated {
                                 old_range: 0..old_messages.summary().count,
                                 new_count: 0,
                             });
-                            channel.loaded_all_messages = loaded_all_messages;
+                            this.loaded_all_messages = loaded_all_messages;
                         }
                     }
 
-                    channel.insert_messages(messages, cx);
+                    this.insert_messages(messages, cx);
                     if loaded_all_messages {
-                        channel.loaded_all_messages = loaded_all_messages;
+                        this.loaded_all_messages = loaded_all_messages;
                     }
+
+                    this.pending_messages().cloned().collect::<Vec<_>>()
                 });
 
+                for pending_message in pending_messages {
+                    let request = rpc.request(proto::SendChannelMessage {
+                        channel_id,
+                        body: pending_message.body,
+                        nonce: Some(pending_message.nonce.into()),
+                    });
+                    let response = request.await?;
+                    let message = ChannelMessage::from_proto(
+                        response.message.ok_or_else(|| anyhow!("invalid message"))?,
+                        &user_store,
+                    )
+                    .await?;
+                    this.update(&mut cx, |this, cx| {
+                        this.insert_messages(SumTree::from_item(message, &()), cx);
+                    });
+                }
+
                 Ok(())
             }
             .log_err()
@@ -365,6 +393,12 @@ impl Channel {
         cursor.take(range.len())
     }
 
+    pub fn pending_messages(&self) -> impl Iterator<Item = &ChannelMessage> {
+        let mut cursor = self.messages.cursor::<ChannelMessageId, ()>();
+        cursor.seek(&ChannelMessageId::Pending(0), Bias::Left, &());
+        cursor
+    }
+
     fn handle_message_sent(
         &mut self,
         message: TypedEnvelope<ChannelMessageSent>,
@@ -391,29 +425,13 @@ impl Channel {
         Ok(())
     }
 
-    fn remove_message(&mut self, message_id: ChannelMessageId, cx: &mut ModelContext<Self>) {
-        let mut old_cursor = self.messages.cursor::<ChannelMessageId, Count>();
-        let mut new_messages = old_cursor.slice(&message_id, Bias::Left, &());
-        let start_ix = old_cursor.sum_start().0;
-        let removed_messages = old_cursor.slice(&message_id, Bias::Right, &());
-        let removed_count = removed_messages.summary().count;
-        new_messages.push_tree(old_cursor.suffix(&()), &());
-
-        drop(old_cursor);
-        self.messages = new_messages;
-
-        if removed_count > 0 {
-            let end_ix = start_ix + removed_count;
-            cx.emit(ChannelEvent::MessagesUpdated {
-                old_range: start_ix..end_ix,
-                new_count: 0,
-            });
-            cx.notify();
-        }
-    }
-
     fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
         if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
+            let nonces = messages
+                .cursor::<(), ()>()
+                .map(|m| m.nonce)
+                .collect::<HashSet<_>>();
+
             let mut old_cursor = self.messages.cursor::<ChannelMessageId, Count>();
             let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
             let start_ix = old_cursor.sum_start().0;
@@ -423,10 +441,40 @@ impl Channel {
             let end_ix = start_ix + removed_count;
 
             new_messages.push_tree(messages, &());
-            new_messages.push_tree(old_cursor.suffix(&()), &());
+
+            let mut ranges = Vec::<Range<usize>>::new();
+            if new_messages.last().unwrap().is_pending() {
+                new_messages.push_tree(old_cursor.suffix(&()), &());
+            } else {
+                new_messages.push_tree(
+                    old_cursor.slice(&ChannelMessageId::Pending(0), Bias::Left, &()),
+                    &(),
+                );
+
+                while let Some(message) = old_cursor.item() {
+                    let message_ix = old_cursor.sum_start().0;
+                    if nonces.contains(&message.nonce) {
+                        if ranges.last().map_or(false, |r| r.end == message_ix) {
+                            ranges.last_mut().unwrap().end += 1;
+                        } else {
+                            ranges.push(message_ix..message_ix + 1);
+                        }
+                    } else {
+                        new_messages.push(message.clone(), &());
+                    }
+                    old_cursor.next(&());
+                }
+            }
+
             drop(old_cursor);
             self.messages = new_messages;
 
+            for range in ranges.into_iter().rev() {
+                cx.emit(ChannelEvent::MessagesUpdated {
+                    old_range: range,
+                    new_count: 0,
+                });
+            }
             cx.emit(ChannelEvent::MessagesUpdated {
                 old_range: start_ix..end_ix,
                 new_count,
@@ -477,6 +525,10 @@ impl ChannelMessage {
             body: message.body,
             timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
             sender,
+            nonce: message
+                .nonce
+                .ok_or_else(|| anyhow!("nonce is required"))?
+                .into(),
         })
     }
 
@@ -606,12 +658,14 @@ mod tests {
                             body: "a".into(),
                             timestamp: 1000,
                             sender_id: 5,
+                            nonce: Some(1.into()),
                         },
                         proto::ChannelMessage {
                             id: 11,
                             body: "b".into(),
                             timestamp: 1001,
                             sender_id: 6,
+                            nonce: Some(2.into()),
                         },
                     ],
                     done: false,
@@ -665,6 +719,7 @@ mod tests {
                     body: "c".into(),
                     timestamp: 1002,
                     sender_id: 7,
+                    nonce: Some(3.into()),
                 }),
             })
             .await;
@@ -720,12 +775,14 @@ mod tests {
                             body: "y".into(),
                             timestamp: 998,
                             sender_id: 5,
+                            nonce: Some(4.into()),
                         },
                         proto::ChannelMessage {
                             id: 9,
                             body: "z".into(),
                             timestamp: 999,
                             sender_id: 6,
+                            nonce: Some(5.into()),
                         },
                     ],
                 },

zrpc/proto/zed.proto 🔗

@@ -151,6 +151,7 @@ message GetUsersResponse {
 message SendChannelMessage {
     uint64 channel_id = 1;
     string body = 2;
+    Nonce nonce = 3;
 }
 
 message SendChannelMessageResponse {
@@ -296,6 +297,11 @@ message Range {
     uint64 end = 2;
 }
 
+message Nonce {
+    uint64 upper_half = 1;
+    uint64 lower_half = 2;
+}
+
 message Channel {
     uint64 id = 1;
     string name = 2;
@@ -306,4 +312,5 @@ message ChannelMessage {
     string body = 2;
     uint64 timestamp = 3;
     uint64 sender_id = 4;
+    Nonce nonce = 5;
 }

zrpc/src/proto.rs 🔗

@@ -248,3 +248,22 @@ impl From<SystemTime> for Timestamp {
         }
     }
 }
+
+impl From<u128> for Nonce {
+    fn from(nonce: u128) -> Self {
+        let upper_half = (nonce >> 64) as u64;
+        let lower_half = nonce as u64;
+        Self {
+            upper_half,
+            lower_half,
+        }
+    }
+}
+
+impl From<Nonce> for u128 {
+    fn from(nonce: Nonce) -> Self {
+        let upper_half = (nonce.upper_half as u128) << 64;
+        let lower_half = nonce.lower_half as u128;
+        upper_half | lower_half
+    }
+}