Start work on sending channel messages

Max Brunsfeld created

Change summary

server/src/rpc.rs    | 16 +++++++-
server/src/tests.rs  | 44 +++++++++++++++--------
zed/src/channel.rs   | 84 +++++++++++++++++++++++++++++++++++++--------
zed/src/rpc.rs       | 23 ++++++++---
zrpc/proto/zed.proto |  8 +++
5 files changed, 133 insertions(+), 42 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -619,12 +619,14 @@ impl Server {
             .app_state
             .db
             .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
-            .await?;
+            .await?
+            .to_proto();
+        let receipt = request.receipt();
         let message = proto::ChannelMessageSent {
             channel_id: channel_id.to_proto(),
             message: Some(proto::ChannelMessage {
                 sender_id: user_id.to_proto(),
-                id: message_id.to_proto(),
+                id: message_id,
                 body: request.payload.body,
                 timestamp: timestamp.unix_timestamp() as u64,
             }),
@@ -633,7 +635,15 @@ impl Server {
             self.peer.send(conn_id, message.clone())
         })
         .await?;
-
+        self.peer
+            .respond(
+                receipt,
+                proto::SendChannelMessageResponse {
+                    message_id,
+                    timestamp: timestamp.unix_timestamp() as u64,
+                },
+            )
+            .await?;
         Ok(())
     }
 

server/src/tests.rs 🔗

@@ -485,13 +485,11 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
     let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
     let (user_id_b, client_b) = server.create_client(&mut cx_a, "user_b").await;
 
-    // Create an org that includes these 2 users and 1 other user.
+    // Create an org that includes these 2 users.
     let db = &server.app_state.db;
-    let user_id_c = db.create_user("user_c", false).await.unwrap();
     let org_id = db.create_org("Test Org", "test-org").await.unwrap();
     db.add_org_member(org_id, user_id_a, false).await.unwrap();
     db.add_org_member(org_id, user_id_b, false).await.unwrap();
-    db.add_org_member(org_id, user_id_c, false).await.unwrap();
 
     // Create a channel that includes all the users.
     let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
@@ -501,13 +499,10 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
     db.add_channel_member(channel_id, user_id_b, false)
         .await
         .unwrap();
-    db.add_channel_member(channel_id, user_id_c, false)
-        .await
-        .unwrap();
     db.create_channel_message(
         channel_id,
-        user_id_c,
-        "first message!",
+        user_id_b,
+        "hello A, it's B.",
         OffsetDateTime::now_utc(),
     )
     .await
@@ -516,9 +511,6 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
     let channels_a = ChannelList::new(client_a, &mut cx_a.to_async())
         .await
         .unwrap();
-    let channels_b = ChannelList::new(client_b, &mut cx_b.to_async())
-        .await
-        .unwrap();
     channels_a.read_with(&cx_a, |list, _| {
         assert_eq!(
             list.available_channels(),
@@ -532,12 +524,33 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
     let channel_a = channels_a.update(&mut cx_a, |this, cx| {
         this.get_channel(channel_id.to_proto(), cx).unwrap()
     });
-
-    channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_none()));
+    channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
     channel_a.next_notification(&cx_a).await;
     channel_a.read_with(&cx_a, |channel, _| {
-        assert_eq!(channel.messages().unwrap().len(), 1);
+        assert_eq!(
+            channel
+                .messages()
+                .iter()
+                .map(|m| (m.sender_id, m.body.as_ref()))
+                .collect::<Vec<_>>(),
+            &[(user_id_b.to_proto(), "hello A, it's B.")]
+        );
     });
+
+    channel_a.update(&mut cx_a, |channel, cx| {
+        channel.send_message("oh, hi B.".to_string(), cx).unwrap();
+        channel.send_message("sup".to_string(), cx).unwrap();
+        assert_eq!(
+            channel
+                .pending_messages()
+                .iter()
+                .map(|m| &m.body)
+                .collect::<Vec<_>>(),
+            &["oh, hi B.", "sup"]
+        )
+    });
+
+    channel_a.next_notification(&cx_a).await;
 }
 
 struct TestServer {
@@ -577,10 +590,9 @@ impl TestServer {
             )
             .detach();
         client
-            .add_connection(client_conn, cx.to_async())
+            .add_connection(user_id.to_proto(), client_conn, cx.to_async())
             .await
             .unwrap();
-
         (user_id, client)
     }
 

zed/src/channel.rs 🔗

@@ -1,5 +1,8 @@
-use crate::rpc::{self, Client};
-use anyhow::{Context, Result};
+use crate::{
+    rpc::{self, Client},
+    util::log_async_errors,
+};
+use anyhow::{anyhow, Context, Result};
 use gpui::{
     AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, WeakModelHandle,
 };
@@ -27,14 +30,24 @@ pub struct ChannelDetails {
 pub struct Channel {
     details: ChannelDetails,
     first_message_id: Option<u64>,
-    messages: Option<VecDeque<ChannelMessage>>,
+    messages: VecDeque<ChannelMessage>,
+    pending_messages: Vec<PendingChannelMessage>,
+    next_local_message_id: u64,
     rpc: Arc<Client>,
     _subscription: rpc::Subscription,
 }
 
 pub struct ChannelMessage {
-    id: u64,
+    pub id: u64,
+    pub sender_id: u64,
+    pub body: String,
+}
+
+pub struct PendingChannelMessage {
+    pub body: String,
+    local_id: u64,
 }
+
 pub enum Event {}
 
 impl Entity for ChannelList {
@@ -110,13 +123,10 @@ impl Channel {
             let channel_id = details.id;
             cx.spawn(|channel, mut cx| async move {
                 match rpc.request(proto::JoinChannel { channel_id }).await {
-                    Ok(response) => {
-                        let messages = response.messages.into_iter().map(Into::into).collect();
-                        channel.update(&mut cx, |channel, cx| {
-                            channel.messages = Some(messages);
-                            cx.notify();
-                        })
-                    }
+                    Ok(response) => channel.update(&mut cx, |channel, cx| {
+                        channel.messages = response.messages.into_iter().map(Into::into).collect();
+                        cx.notify();
+                    }),
                     Err(error) => log::error!("error joining channel: {}", error),
                 }
             })
@@ -127,14 +137,54 @@ impl Channel {
             details,
             rpc,
             first_message_id: None,
-            messages: None,
+            messages: Default::default(),
+            pending_messages: Default::default(),
+            next_local_message_id: 0,
             _subscription,
         }
     }
 
+    pub fn send_message(&mut self, body: String, cx: &mut ModelContext<Self>) -> Result<()> {
+        let channel_id = self.details.id;
+        let current_user_id = self.rpc.user_id().ok_or_else(|| anyhow!("not logged in"))?;
+        let local_id = self.next_local_message_id;
+        self.next_local_message_id += 1;
+        self.pending_messages.push(PendingChannelMessage {
+            local_id,
+            body: body.clone(),
+        });
+        let rpc = self.rpc.clone();
+        cx.spawn(|this, mut cx| {
+            log_async_errors(async move {
+                let request = rpc.request(proto::SendChannelMessage { channel_id, body });
+                let response = request.await?;
+                this.update(&mut cx, |this, cx| {
+                    if let Ok(i) = this
+                        .pending_messages
+                        .binary_search_by_key(&local_id, |msg| msg.local_id)
+                    {
+                        let body = this.pending_messages.remove(i).body;
+                        this.messages.push_back(ChannelMessage {
+                            id: response.message_id,
+                            sender_id: current_user_id,
+                            body,
+                        });
+                        cx.notify();
+                    }
+                });
+                Ok(())
+            })
+        })
+        .detach();
+        Ok(())
+    }
+
+    pub fn messages(&self) -> &VecDeque<ChannelMessage> {
+        &self.messages
+    }
 
-    pub fn messages(&self) -> Option<&VecDeque<ChannelMessage>> {
-        self.messages.as_ref()
+    pub fn pending_messages(&self) -> &[PendingChannelMessage] {
+        &self.pending_messages
     }
 
     fn handle_message_sent(
@@ -158,6 +208,10 @@ impl From<proto::Channel> for ChannelDetails {
 
 impl From<proto::ChannelMessage> for ChannelMessage {
     fn from(message: proto::ChannelMessage) -> Self {
-        ChannelMessage { id: message.id }
+        ChannelMessage {
+            id: message.id,
+            sender_id: message.sender_id,
+            body: message.body,
+        }
     }
 }

zed/src/rpc.rs 🔗

@@ -31,6 +31,7 @@ pub struct Client {
 #[derive(Default)]
 struct ClientState {
     connection_id: Option<ConnectionId>,
+    user_id: Option<u64>,
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
     model_handlers: HashMap<
         (TypeId, u64),
@@ -66,6 +67,10 @@ impl Client {
         })
     }
 
+    pub fn user_id(&self) -> Option<u64> {
+        self.state.read().user_id
+    }
+
     pub fn subscribe_from_model<T, M, F>(
         self: &Arc<Self>,
         remote_id: u64,
@@ -125,7 +130,7 @@ impl Client {
         }
 
         let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
-        let user_id: i32 = user_id.parse()?;
+        let user_id = user_id.parse::<u64>()?;
         let request =
             Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
 
@@ -135,23 +140,25 @@ impl Client {
             let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
                 .await
                 .context("websocket handshake")?;
-            log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-            self.add_connection(stream, cx).await?;
+            self.add_connection(user_id, stream, cx).await?;
         } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
             let stream = smol::net::TcpStream::connect(host).await?;
             let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
-            let (stream, _) = async_tungstenite::client_async(request, stream).await?;
-            log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-            self.add_connection(stream, cx).await?;
+            let (stream, _) = async_tungstenite::client_async(request, stream)
+                .await
+                .context("websocket handshake")?;
+            self.add_connection(user_id, stream, cx).await?;
         } else {
             return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
         };
 
+        log::info!("connected to rpc address {}", *ZED_SERVER_URL);
         Ok(())
     }
 
     pub async fn add_connection<Conn>(
         self: &Arc<Self>,
+        user_id: u64,
         conn: Conn,
         cx: AsyncAppContext,
     ) -> surf::Result<()>
@@ -202,7 +209,9 @@ impl Client {
                 }
             })
             .detach();
-        self.state.write().connection_id = Some(connection_id);
+        let mut state = self.state.write();
+        state.connection_id = Some(connection_id);
+        state.user_id = Some(user_id);
         Ok(())
     }
 

zrpc/proto/zed.proto 🔗

@@ -30,7 +30,8 @@ message Envelope {
         JoinChannelResponse join_channel_response = 25;
         LeaveChannel leave_channel = 26;
         SendChannelMessage send_channel_message = 27;
-        ChannelMessageSent channel_message_sent = 28;
+        SendChannelMessageResponse send_channel_message_response = 28;
+        ChannelMessageSent channel_message_sent = 29;
     }
 }
 
@@ -148,6 +149,11 @@ message SendChannelMessage {
     string body = 2;
 }
 
+message SendChannelMessageResponse {
+    uint64 message_id = 1;
+    uint64 timestamp = 2;
+}
+
 message ChannelMessageSent {
     uint64 channel_id = 1;
     ChannelMessage message = 2;